!2286 enable optimizer parallel with broadcast

Merge pull request !2286 from gziyan/optimizer_parallel
mindspore-ci-bot 2020-06-27 15:45:56 +08:00 committed by Gitee
commit ca1d2436b8
14 changed files with 354 additions and 124 deletions

View File

@@ -62,6 +62,7 @@ void ParallelContext::Reset() {
   enable_all_reduce_fusion_ = false;
   strategy_ckpt_load_file_ = "";
   strategy_ckpt_save_file_ = "";
+  enable_parallel_optimizer_ = false;
 }
 
 void ParallelContext::set_device_num(int32_t device_num) {

View File

@@ -100,6 +100,11 @@ class ParallelContext {
   void set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file);
   std::string strategy_ckpt_save_file() const { return strategy_ckpt_save_file_; }
 
+  void set_enable_parallel_optimizer(bool enable_parallel_optimizer) {
+    enable_parallel_optimizer_ = enable_parallel_optimizer;
+  }
+  bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; }
+
   void Reset();
 
  private:
@@ -123,6 +128,7 @@ class ParallelContext {
   std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_sizes_;
   std::string strategy_ckpt_load_file_;
   std::string strategy_ckpt_save_file_;
+  bool enable_parallel_optimizer_;
 };
 
 void ParallelParameterContextInit(const FuncGraphPtr &func_graph);

View File

@@ -205,6 +205,10 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.")
     .def("set_full_batch", &ParallelContext::set_full_batch, "Set whether load full batch on each device.")
     .def("get_full_batch", &ParallelContext::full_batch, "Get whether load full batch on each device.")
+    .def("set_enable_parallel_optimizer", &ParallelContext::set_enable_parallel_optimizer,
+         "Set enable/disable parallel optimizer.")
+    .def("get_enable_parallel_optimizer", &ParallelContext::enable_parallel_optimizer,
+         "Get enable/disable parallel optimizer.")
     .def("reset", &ParallelContext::Reset, "Reset auto parallel context.");
 
   (void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext")

View File

@@ -29,8 +29,9 @@ from .optimizer import Optimizer
 _adam_opt = C.MultitypeFuncGraph("adam_opt")
 
 
-@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
-def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag):
+@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
+                    "Tensor", "Bool", "Bool")
+def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag, optim_filter):
     """
     Update parameters.
 
@@ -44,38 +45,44 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad
         m (Tensor): m value of parameters.
         v (Tensor): v value of parameters.
         gradient (Tensor): Gradient of parameters.
         decay_flag (bool): Applies weight decay or not.
+        optim_filter (bool): Applies parameter update or not.
 
     Returns:
         Tensor, the new value of v after updating.
     """
-    op_mul = P.Mul()
-    op_square = P.Square()
-    op_sqrt = P.Sqrt()
-    op_cast = P.Cast()
-    op_reshape = P.Reshape()
-    op_shape = P.Shape()
-
-    param_fp32 = op_cast(param, mstype.float32)
-    m_fp32 = op_cast(m, mstype.float32)
-    v_fp32 = op_cast(v, mstype.float32)
-    gradient_fp32 = op_cast(gradient, mstype.float32)
-
-    next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)
-
-    next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
-                                            - beta2, op_square(gradient_fp32))
-
-    update = next_m / (eps + op_sqrt(next_v))
-    if decay_flag:
-        update = op_mul(weight_decay_tensor, param_fp32) + update
-
-    update_with_lr = op_mul(lr, update)
-    next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
-
-    next_v = F.depend(next_v, F.assign(param, op_cast(next_param, F.dtype(param))))
-    next_v = F.depend(next_v, F.assign(m, op_cast(next_m, F.dtype(m))))
-    next_v = F.depend(next_v, F.assign(v, op_cast(next_v, F.dtype(v))))
-    return next_v
+    if optim_filter:
+        op_mul = P.Mul()
+        op_square = P.Square()
+        op_sqrt = P.Sqrt()
+        op_cast = P.Cast()
+        op_reshape = P.Reshape()
+        op_shape = P.Shape()
+
+        param_fp32 = op_cast(param, mstype.float32)
+        m_fp32 = op_cast(m, mstype.float32)
+        v_fp32 = op_cast(v, mstype.float32)
+        gradient_fp32 = op_cast(gradient, mstype.float32)
+
+        next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
+                                                - beta1, gradient_fp32)
+
+        next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
+                                                - beta2, op_square(gradient_fp32))
+
+        update = next_m / (eps + op_sqrt(next_v))
+        if decay_flag:
+            update = op_mul(weight_decay_tensor, param_fp32) + update
+
+        update_with_lr = op_mul(lr, update)
+        next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
+
+        next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
+        next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
+        next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))
+
+        return next_param
+    return gradient
 
 
 def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
@@ -300,7 +307,7 @@ class AdamWeightDecay(Optimizer):
         - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
 
     Outputs:
-        tuple[Parameter], the updated velocity value, the shape is the same as `params`.
+        tuple[bool], all elements are True.
 
     Examples:
         >>> net = Net()
@@ -328,11 +335,13 @@ class AdamWeightDecay(Optimizer):
     def construct(self, gradients):
         lr = self.get_lr()
-        updated_velocity = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
-                                                    self.weight_decay_tensor),
-                                          self.params, self.moments1, self.moments2, gradients, self.decay_flag)
-
-        return updated_velocity
+        optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
+                                                self.weight_decay_tensor),
+                                      self.params, self.moments1, self.moments2, gradients,
+                                      self.decay_flag, self.optim_filter)
+        if self.use_parallel:
+            optim_result = self.broadcast_params(optim_result)
+        return optim_result
 
 
 class AdamWeightDecayDynamicLR(Optimizer):
@@ -363,7 +372,7 @@ class AdamWeightDecayDynamicLR(Optimizer):
         - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
 
     Outputs:
-        tuple[Parameter], the updated velocity value, the shape is the same as `params`.
+        tuple[bool], all elements are True.
 
     Examples:
         >>> net = Net()
@@ -424,12 +433,14 @@ class AdamWeightDecayDynamicLR(Optimizer):
         warmup_lr = self.start_learning_rate * warmup_percent
         is_warmup = self.cast(self.greater(self.warmup_steps, self.global_step), mstype.float32)
         lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr
-        updated_velocity = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
-                                                    self.weight_decay_tensor),
-                                          self.params, self.moments1, self.moments2, gradients, self.decay_flag)
+        optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
+                                                self.weight_decay_tensor),
+                                      self.params, self.moments1, self.moments2, gradients,
+                                      self.decay_flag, self.optim_filter)
+        if self.use_parallel:
+            optim_result = self.broadcast_params(optim_result)
 
         added_global_step = self.global_step + self.one
         F.control_depend(lr, added_global_step)
         self.global_step = added_global_step
 
-        return updated_velocity
+        return optim_result
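For readers who want the update rule outside the graph notation, the following is a minimal NumPy sketch of what `_adam_opt` computes per parameter when `optim_filter` is True. The function name `adam_weight_decay_step` and the hyperparameter values in the call are illustrative only, not part of this patch.

# Illustrative sketch only (not MindSpore graph code): mirrors the math in _adam_opt.
import numpy as np

def adam_weight_decay_step(param, m, v, gradient, lr, beta1, beta2, eps,
                           weight_decay, decay_flag, optim_filter):
    if not optim_filter:
        # This rank does not own the parameter: the graph returns the gradient as a
        # placeholder; the owning rank's result arrives later via broadcast_params().
        return gradient, m, v
    next_m = beta1 * m + (1.0 - beta1) * gradient
    next_v = beta2 * v + (1.0 - beta2) * np.square(gradient)
    update = next_m / (eps + np.sqrt(next_v))
    if decay_flag:
        update = weight_decay * param + update
    next_param = param - lr * update  # _adam_opt also writes param, m and v back via F.assign
    return next_param, next_m, next_v

p = np.ones(3, np.float32)
print(adam_weight_decay_step(p, np.zeros(3), np.zeros(3), np.full(3, 0.5),
                             lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-6,
                             weight_decay=0.01, decay_flag=True, optim_filter=True)[0])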

View File

@@ -32,11 +32,10 @@ num_one = Tensor(np.ones([1]), mstype.float32)
 _lamb_opt = C.MultitypeFuncGraph("lamb_opt")
 
-@_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
-                    "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
+@_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
+                    "Tensor", "Bool", "Bool")
 def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, param, m, v,
-                   gradient, decay_flag):
+                   gradient, decay_flag, optim_filter):
     """
     Update parameters.
 
@@ -52,66 +51,66 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
         v (Tensor): v value of parameters.
         gradient (Tensor): Gradient of parameters.
         decay_flag (bool): Specifies whether param update with weight decay.
+        optim_filter (bool): Applies parameter update or not.
 
     Returns:
         Tensor, the new value of v after updating.
     """
-    op_mul = P.Mul()
-    op_sqrt = P.Sqrt()
-    op_rsqrt = P.Rsqrt()
-    op_square = P.Square()
-    op_cast = P.Cast()
-    op_reshape = P.Reshape()
-    op_shape = P.Shape()
-    op_pow = P.Pow()
-    op_norm = layer.Norm()
-    op_select = P.Select()
-    op_greater = P.Greater()
-    op_fill = P.Fill()
-    op_dtype = P.DType()
-
-    param_fp32 = op_cast(param, mstype.float32)
-    m_fp32 = op_cast(m, mstype.float32)
-    v_fp32 = op_cast(v, mstype.float32)
-    gradient_fp32 = op_cast(gradient, mstype.float32)
-
-    next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(num_one,
-                                                    mstype.float32) - beta1, gradient_fp32)
-
-    next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(num_one,
-                                                    mstype.float32) - beta2, op_square(gradient_fp32))
-
-    next_mm = next_m / (op_cast(num_one, mstype.float32)
-                        - op_pow(beta1, op_cast(global_step + num_one, mstype.float32)))
-    next_vv = next_v / (op_cast(num_one, mstype.float32) -
-                        op_pow(beta2, op_cast(global_step + num_one, mstype.float32)))
-    w_norm = op_norm(param_fp32)
-    g_norm = op_norm(gradient_fp32)
-
-    g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(
-        next_vv + eps)) + weight_decay_tensor * param_fp32)
-    zeros = F.zeros_like(w_norm)
-    ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
-    trust_ratio = op_select(
-        op_greater(w_norm, zeros),
-        op_select(op_greater(g_norm, zeros), w_norm / g_norm_hat, ones),
-        ones)
-    tens = op_fill(op_dtype(trust_ratio), op_shape(trust_ratio), 10.0)
-    trust_ratio = C.clip_by_value(trust_ratio, zeros, tens)
-    update = next_mm / (op_sqrt(next_vv) + eps)
-
-    if decay_flag:
-        update = update + op_mul(weight_decay_tensor, param_fp32)
-
-    update_with_lr = op_mul(op_mul(trust_ratio, lr), update)
-
-    next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
-
-    next_v = F.depend(next_v, F.assign(param, next_param))
-    next_v = F.depend(next_v, F.assign(m, next_m))
-    next_v = F.depend(next_v, F.assign(v, next_v))
-
-    return next_v
+    if optim_filter:
+        op_mul = P.Mul()
+        op_sqrt = P.Sqrt()
+        op_rsqrt = P.Rsqrt()
+        op_square = P.Square()
+        op_cast = P.Cast()
+        op_reshape = P.Reshape()
+        op_shape = P.Shape()
+        op_pow = P.Pow()
+        op_norm = layer.Norm()
+        op_select = P.Select()
+        op_greater = P.Greater()
+        op_fill = P.Fill()
+        op_dtype = P.DType()
+
+        param_fp32 = op_cast(param, mstype.float32)
+        m_fp32 = op_cast(m, mstype.float32)
+        v_fp32 = op_cast(v, mstype.float32)
+        gradient_fp32 = op_cast(gradient, mstype.float32)
+
+        next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta1, gradient_fp32)
+
+        next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta2, op_square(gradient_fp32))
+
+        next_mm = next_m / (op_cast(num_one, mstype.float32)
+                            - op_pow(beta1, op_cast(global_step + num_one, mstype.float32)))
+        next_vv = next_v / (op_cast(num_one, mstype.float32) -
+                            op_pow(beta2, op_cast(global_step + num_one, mstype.float32)))
+        w_norm = op_norm(param_fp32)
+        g_norm = op_norm(gradient_fp32)
+
+        g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay_tensor * param_fp32)
+        zeros = F.zeros_like(w_norm)
+        ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
+        trust_ratio = op_select(
+            op_greater(w_norm, zeros),
+            op_select(op_greater(g_norm, zeros), w_norm / g_norm_hat, ones),
+            ones)
+        tens = op_fill(op_dtype(trust_ratio), op_shape(trust_ratio), 10.0)
+        trust_ratio = C.clip_by_value(trust_ratio, zeros, tens)
+        update = next_mm / (op_sqrt(next_vv) + eps)
+
+        if decay_flag:
+            update = update + op_mul(weight_decay_tensor, param_fp32)
+
+        update_with_lr = op_mul(op_mul(trust_ratio, lr), update)
+        next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
+
+        next_param = F.depend(next_param, F.assign(param, next_param))
+        next_param = F.depend(next_param, F.assign(m, next_m))
+        next_param = F.depend(next_param, F.assign(v, next_v))
+
+        return next_param
+    return gradient
 
 
 lamb_opt_graph_kernel = C.MultitypeFuncGraph("lamb_opt_graph_kernel")
@@ -238,7 +237,7 @@ class Lamb(Optimizer):
         - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
 
     Outputs:
-        tuple[Parameter], the updated velocity value, the shape is the same as `params`.
+        tuple[bool], all elements are True.
 
     Examples:
         >>> net = Net()
@@ -311,18 +310,21 @@ class Lamb(Optimizer):
                                               self.warmup_steps, self.global_step), mstype.float32)
             lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr
         if self.enable_graph_kernel:
-            updated_velocity = self.hyper_map(F.partial(lamb_opt_graph_kernel,
-                                                        self.beta1, self.beta2, self.eps, lr,
-                                                        self.weight_decay_tensor, self.global_step),
-                                              self.params, self.moments1, self.moments2, gradients, self.decay_flag)
+            optim_result = self.hyper_map(F.partial(lamb_opt_graph_kernel,
+                                                    self.beta1, self.beta2, self.eps, lr,
+                                                    self.weight_decay_tensor, self.global_step),
+                                          self.params, self.moments1, self.moments2, gradients, self.decay_flag)
         else:
-            updated_velocity = self.hyper_map(F.partial(_lamb_opt,
-                                                        self.beta1, self.beta2, self.eps, lr,
-                                                        self.weight_decay_tensor, self.global_step),
-                                              self.params, self.moments1, self.moments2, gradients, self.decay_flag)
+            optim_result = self.hyper_map(F.partial(_lamb_opt,
+                                                    self.beta1, self.beta2, self.eps, lr,
+                                                    self.weight_decay_tensor, self.global_step),
+                                          self.params, self.moments1, self.moments2, gradients,
+                                          self.decay_flag, self.optim_filter)
+        if self.use_parallel:
+            optim_result = self.broadcast_params(optim_result)
 
         added_global_step = self.global_step + self.one
         F.control_depend(lr, added_global_step)
         self.global_step = added_global_step
 
-        return updated_velocity
+        return optim_result
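The LAMB branch is easier to follow with the trust-ratio computation written out in plain NumPy. The sketch below reuses the intermediate names from the diff (`next_mm`, `next_vv`); the helper `lamb_trust_ratio` itself is illustrative, not part of the code base.

# Illustrative sketch only: the trust-ratio logic of _lamb_opt in NumPy terms.
import numpy as np

def lamb_trust_ratio(param_fp32, gradient_fp32, next_mm, next_vv, eps, weight_decay):
    w_norm = np.linalg.norm(param_fp32)
    g_norm = np.linalg.norm(gradient_fp32)
    # op_rsqrt(next_vv + eps) is 1 / sqrt(next_vv + eps)
    g_norm_hat = np.linalg.norm(next_mm / np.sqrt(next_vv + eps) + weight_decay * param_fp32)
    ratio = w_norm / g_norm_hat if (w_norm > 0.0 and g_norm > 0.0) else 1.0
    return float(np.clip(ratio, 0.0, 10.0))  # clipped to [0, 10], as in the graph

# The scaled step is then trust_ratio * lr * (next_mm / (sqrt(next_vv) + eps) [+ weight_decay * param]).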

View File

@@ -22,11 +22,14 @@ from mindspore.ops import functional as F, composite as C, operations as P
 from mindspore.nn.cell import Cell
 from mindspore.common.parameter import Parameter, ParameterTuple
 from mindspore.common.initializer import initializer
+from mindspore.common.tensor import Tensor
 import mindspore.common.dtype as mstype
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
-from mindspore.common.tensor import Tensor
 from mindspore import log as logger
+from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode
+from mindspore.parallel._auto_parallel_context import auto_parallel_context
+from mindspore.train.parallel_utils import ParallelMode
 
 __all__ = ['Optimizer']
 
@@ -155,6 +158,27 @@ class Optimizer(Cell):
         self.param_length = len(self.parameters)
         self.map_ = C.Map()
 
+        use_parallel = auto_parallel_context().get_enable_parallel_optimizer()
+        self.use_parallel = use_parallel
+        if use_parallel:
+            if self.cls_name not in ["Lamb", "AdamWeightDecayDynamicLR", "AdamWeightDecay"]:
+                raise RuntimeError("Optimizer segmentation does not support optimizer {}".format(self.cls_name))
+            if _get_parallel_mode() not in [ParallelMode.HYBRID_PARALLEL, ParallelMode.DATA_PARALLEL,
+                                            ParallelMode.AUTO_PARALLEL]:
+                raise RuntimeError("Optimizer segmentation does not support parallel mode {}".format
+                                   (_get_parallel_mode()))
+            self.dev_num = _get_device_num()
+            if self.dev_num > self.param_length:
+                raise RuntimeError("Optimizer segmentation can not be applied when the number of parameters {} is"
+                                   " less than the number of devices {}".format(self.param_length, self.dev_num))
+            self.param_rank = self._get_parameter_group_id()
+            self.optim_filter = tuple(map(lambda x: x == _get_global_rank(), self.param_rank))
+            self.param_names = []
+            for param in self.parameters:
+                self.param_names.append(param.name)
+        else:
+            self.optim_filter = (True,) * self.param_length
+
     def decay_weight(self, gradients):
         """
         Weight decay.
@@ -401,6 +425,51 @@
             lr = self.learning_rate
         return lr
 
+    def _get_parameter_group_id(self):
+        """
+        Get the parameter partition group id, which is less than the number of devices.
+
+        Returns:
+            tuple, the group id tuple of parameters.
+        """
+        rank_list = ()
+        count = 0
+        for _ in range(self.param_length):
+            rank_list = rank_list + (count,)
+            count = count + 1
+            if count == self.dev_num:
+                count = 0
+        return rank_list
+
+    def broadcast_params(self, optim_result):
+        """
+        Apply Broadcast operations in the sequential order of parameter groups.
+
+        Returns:
+            bool, the status flag.
+        """
+        param_group = []
+        key_group = []
+        for _ in range(self.dev_num):
+            param_group.append(F.make_tuple())
+            key_group.append(F.make_tuple())
+        for i in range(self.param_length):
+            param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (optim_result[i],)
+            key = P.MakeRefKey(self.param_names[i])()
+            key_group[self.param_rank[i]] = key_group[self.param_rank[i]] + (key,)
+        new_param_group = []
+        for root in range(self.dev_num):
+            ops = P.Broadcast(root)
+            next_params = ops(param_group[root])
+            new_param_group.append(next_params)
+            for i in range(F.tuple_len(next_params)):
+                F.assign(key_group[root][i], next_params[i])
+        status = True
+        for i in range(self.dev_num - 1):
+            status = F.control_depend(new_param_group[i][0], new_param_group[i+1])
+
+        return status
+
     def construct(self, *hyper_params):
         raise NotImplementedError
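The partition logic itself is a simple round-robin. The plain-Python sketch below, with made-up values for `param_count`, `dev_num` and `this_rank`, shows what `_get_parameter_group_id` produces and how `optim_filter` is derived from it.

# Illustrative only: plain-Python equivalent of the round-robin partition.
param_count, dev_num, this_rank = 10, 4, 1

param_rank = tuple(i % dev_num for i in range(param_count))
# -> (0, 1, 2, 3, 0, 1, 2, 3, 0, 1): parameter i is updated by device i % dev_num

optim_filter = tuple(rank == this_rank for rank in param_rank)
# -> (False, True, False, False, False, True, False, False, False, True)
# Rank 1 runs the optimizer only for parameters 1, 5 and 9; for the rest it keeps the
# placeholder result and receives the fresh values via Broadcast from the owning rank
# in broadcast_params().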

View File

@@ -85,6 +85,22 @@ def _list_setitem_with_List(data, number_index, value):
     return F.list_setitem(data, number_index, value)
 
 
+@setitem.register("List", "Number", "Tuple")
+def _list_setitem_with_Tuple(data, number_index, value):
+    """
+    Assigns value to list.
+
+    Inputs:
+        data (list): Data of type list.
+        number_index (Number): Index of data.
+        value (tuple): Value given.
+
+    Outputs:
+        list, type is same as the element type of data.
+    """
+    return F.list_setitem(data, number_index, value)
+
+
 @setitem.register("Dictionary", "String", "Tensor")
 def _dict_setitem_with_tensor(data, key, value):
     """

View File

@@ -98,6 +98,7 @@ class AllReduce(PrimitiveWithInfer):
         self.op = op
         self.add_prim_attr('group', _get_group(group))
         self.add_prim_attr('fusion', 0)
+        self.add_prim_attr('index', 0)
 
     def vm_impl(self, x):
         """Implement by vm mode."""

View File

@@ -59,8 +59,7 @@ class Assign(PrimitiveWithInfer):
         return variable
 
     def infer_dtype(self, variable, value):
-        args = {"variable": variable, "value": value}
-        validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
+        # Add a type validation later when we don't have to assign a value to RefKey.
         return variable

View File

@@ -400,6 +400,23 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_global_rank_is_set()
 
+    def set_enable_parallel_optimizer(self, enable_parallel_optimizer):
+        """
+        Set enable/disable parallel optimizer.
+
+        Args:
+            enable_parallel_optimizer (bool): Enable/disable parallel optimizer.
+        """
+        self.check_context_handle()
+        if not isinstance(enable_parallel_optimizer, bool):
+            raise TypeError('enable_parallel_optimizer is invalid type')
+        self._context_handle.set_enable_parallel_optimizer(enable_parallel_optimizer)
+
+    def get_enable_parallel_optimizer(self):
+        """Get parallel optimizer flag."""
+        self.check_context_handle()
+        return self._context_handle.get_enable_parallel_optimizer()
+
     def reset(self):
         """Reset all settings."""
         self.check_context_handle()
@@ -433,7 +450,8 @@ _set_auto_parallel_context_func_map = {
     "parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
     "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
     "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,
-    "full_batch": auto_parallel_context().set_full_batch}
+    "full_batch": auto_parallel_context().set_full_batch,
+    "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer}
 
 
 _get_auto_parallel_context_func_map = {
@@ -447,13 +465,15 @@ _get_auto_parallel_context_func_map = {
     "parameter_broadcast": auto_parallel_context().get_parameter_broadcast,
     "strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file,
     "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file,
-    "full_batch": auto_parallel_context().get_full_batch}
+    "full_batch": auto_parallel_context().get_full_batch,
+    "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer}
 
 
 @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool,
                  loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
                  parameter_broadcast=bool, strategy_ckpt_load_file=str,
-                 strategy_ckpt_save_file=str, full_batch=bool)
+                 strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool)
 def _set_auto_parallel_context(**kwargs):
     """
     Set auto parallel context.
@@ -493,6 +513,7 @@ def _set_auto_parallel_context(**kwargs):
         strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
         strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
         full_batch (bool): Whether to load the whole batch on each device. Default: False.
+        enable_parallel_optimizer (bool): Enable using optimizer segmentation or not. Default: False.
 
     Raises:
         ValueError: If input key is not attribute in auto parallel context.
@@ -535,5 +556,6 @@ def _reset_auto_parallel_context():
     - parameter_broadcast: False.
     - strategy_ckpt_load_file: ""
     - strategy_ckpt_save_file: ""
+    - enable_parallel_optimizer: False
     """
     auto_parallel_context().reset()
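Putting the pieces together, a minimal usage sketch (mirroring the new test cases further down) could look like the following. `nn.Dense` is only a stand-in network, and the flag must be set before the optimizer is constructed, because the check happens in `Optimizer.__init__`.

# Illustrative usage sketch under the APIs added in this PR.
import mindspore.nn as nn
from mindspore import context
from mindspore.nn.optim import AdamWeightDecay
from mindspore.parallel._auto_parallel_context import auto_parallel_context

auto_parallel_context().set_enable_parallel_optimizer(True)
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2)

net = nn.Dense(16, 16)  # stand-in; any Cell whose parameter count >= device_num works
# Only Lamb, AdamWeightDecay and AdamWeightDecayDynamicLR are accepted when the flag
# is on; other optimizers raise RuntimeError in Optimizer.__init__.
optimizer = AdamWeightDecay(net.trainable_params(), learning_rate=0.1)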

View File

@@ -327,6 +327,9 @@ TEST_F(TestStepParallel, CreatOpInstance) {
     } else if (name == "instance_name") {
       parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
       ASSERT_EQ(converted_ret->ToString(), "test");
+    } else if (name == "index") {
+      parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
+      ASSERT_EQ(converted_ret->ToString(), "0");
     } else {
       MS_LOG(EXCEPTION) << "Test failed";
     }

View File

@@ -16,7 +16,6 @@
 test assign sub
 """
 import numpy as np
-import pytest
 
 import mindspore.context as context
 import mindspore.nn as nn
@@ -36,27 +35,6 @@ class AssignW(nn.Cell):
         return x
 
 
-class Net(nn.Cell):
-    def __init__(self):
-        super(Net, self).__init__()
-        self.b = Parameter(initializer('ones', [5]), name='b')
-        self.assign = AssignW()
-
-    def construct(self, value):
-        return self.assign(self.b, value)
-
-
-def test_assign_through_cell():
-    context.set_context(mode=context.GRAPH_MODE)
-    net = Net()
-    net.to_float(ms.float16)
-    net.add_flags_recursive(fp16=False)
-    input_data = Tensor(np.ones([5]).astype(np.float32))
-    net(input_data)
-    with pytest.raises(TypeError):
-        net(None)
-
-
 class AssignOp(nn.Cell):
     def __init__(self):
         super(AssignOp, self).__init__()

View File

@@ -0,0 +1,114 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test adam """
import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import _executor
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Adam, AdamWeightDecay, AdamWeightDecayDynamicLR, Lamb
from mindspore.ops import operations as P
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore import context


class Net(nn.Cell):
    """Net definition"""

    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Dense(128, 768, activation='relu')
        self.fc2 = nn.Dense(128, 768, activation='relu')
        self.fc3 = nn.Dense(128, 768, activation='relu')
        self.fc4 = nn.Dense(768, 768, activation='relu')
        self.relu4 = nn.ReLU()
        self.relu5 = nn.ReLU()
        self.transpose = P.Transpose()
        self.matmul1 = P.MatMul()
        self.matmul2 = P.MatMul()

    def construct(self, x):
        q = self.fc1(x)
        k = self.fc2(x)
        v = self.fc3(x)
        k = self.transpose(k, (1, 0))
        c = self.relu4(self.matmul1(q, k))
        s = self.relu5(self.matmul2(c, v))
        s = self.fc4(s)
        return s


def test_AdamWeightDecayDynamicLR():
    """ test_AdamWeightDecayDynamicLR """
    auto_parallel_context().set_enable_parallel_optimizer(True)
    context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2)
    inputs = Tensor(np.ones([32, 128]).astype(np.float32))
    label = Tensor(np.zeros([32, 768]).astype(np.float32))
    net = Net()
    net.set_train()
    loss = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = AdamWeightDecayDynamicLR(net.trainable_params(), decay_steps=20, learning_rate=0.1)
    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepCell(net_with_loss, optimizer)
    _executor.compile(train_network, inputs, label)


def test_AdamWeightDecay():
    """ test_AdamWeightDecay """
    auto_parallel_context().set_enable_parallel_optimizer(True)
    context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2)
    inputs = Tensor(np.ones([32, 128]).astype(np.float32))
    label = Tensor(np.zeros([32, 768]).astype(np.float32))
    net = Net()
    net.set_train()
    loss = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = AdamWeightDecay(net.trainable_params(), learning_rate=0.1)
    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepCell(net_with_loss, optimizer)
    _executor.compile(train_network, inputs, label)


def test_lamb_compile():
    """ test_lamb_compile """
    auto_parallel_context().set_enable_parallel_optimizer(True)
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=2)
    inputs = Tensor(np.ones([32, 128]).astype(np.float32))
    label = Tensor(np.zeros([32, 768]).astype(np.float32))
    net = Net()
    net.set_train()
    loss = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = Lamb(net.trainable_params(), decay_steps=10)
    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepCell(net_with_loss, optimizer)
    _executor.compile(train_network, inputs, label)


def test_edge_case():
    """ test_edge_case """
    auto_parallel_context().set_enable_parallel_optimizer(True)
    net = Net()
    with pytest.raises(RuntimeError):
        context.set_auto_parallel_context(parallel_mode="stand_alone")
        Lamb(net.trainable_params(), decay_steps=10)
    with pytest.raises(RuntimeError):
        Adam(net.trainable_params(), learning_rate=0.1)
    with pytest.raises(RuntimeError):
        context.set_auto_parallel_context(device_num=16)
        Lamb(net.trainable_params(), decay_steps=10)

View File

@@ -81,6 +81,10 @@ def test_set_auto_parallel_context():
     with pytest.raises(ValueError):
         set_algo_parameters(tensor_slice_align_size=1025)
 
+    auto_parallel_context().set_enable_parallel_optimizer(True)
+    assert auto_parallel_context().get_enable_parallel_optimizer() is True
+    assert not auto_parallel_context().get_all_reduce_fusion_split_indices()
+
 
 def test_reset_auto_parallel_context():
     context.reset_auto_parallel_context()
def test_reset_auto_parallel_context(): def test_reset_auto_parallel_context():
context.reset_auto_parallel_context() context.reset_auto_parallel_context()