forked from mindspore-Ecosystem/mindspore

commit ca1d2436b8
!2286 enable optimizer parallel with broadcast
Merge pull request !2286 from gziyan/optimizer_parallel
@@ -62,6 +62,7 @@ void ParallelContext::Reset() {
   enable_all_reduce_fusion_ = false;
   strategy_ckpt_load_file_ = "";
   strategy_ckpt_save_file_ = "";
+  enable_parallel_optimizer_ = false;
 }

 void ParallelContext::set_device_num(int32_t device_num) {
@@ -100,6 +100,11 @@ class ParallelContext {
   void set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file);
   std::string strategy_ckpt_save_file() const { return strategy_ckpt_save_file_; }

+  void set_enable_parallel_optimizer(bool enable_parallel_optimizer) {
+    enable_parallel_optimizer_ = enable_parallel_optimizer;
+  }
+  bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; }
+
   void Reset();

  private:
@@ -123,6 +128,7 @@ class ParallelContext {
   std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_sizes_;
   std::string strategy_ckpt_load_file_;
   std::string strategy_ckpt_save_file_;
+  bool enable_parallel_optimizer_;
 };

 void ParallelParameterContextInit(const FuncGraphPtr &func_graph);
@@ -205,6 +205,10 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.")
     .def("set_full_batch", &ParallelContext::set_full_batch, "Set whether load full batch on each device.")
     .def("get_full_batch", &ParallelContext::full_batch, "Get whether load full batch on each device.")
+    .def("set_enable_parallel_optimizer", &ParallelContext::set_enable_parallel_optimizer,
+         "Set enable/disable parallel optimizer.")
+    .def("get_enable_parallel_optimizer", &ParallelContext::enable_parallel_optimizer,
+         "Get enable/disable parallel optimizer.")
     .def("reset", &ParallelContext::Reset, "Reset auto parallel context.");

   (void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext")
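Note: the two bindings added above are the only bridge the Python layer needs; everything else in this PR goes through them. A minimal sketch of the round trip via the `auto_parallel_context()` wrapper extended later in this diff (assumes a build of this branch; the default-`False` behaviour after `reset()` comes from the `ParallelContext::Reset()` hunk above):

    # Sketch only: exercise the new set/get/reset path end to end.
    from mindspore.parallel._auto_parallel_context import auto_parallel_context

    ctx = auto_parallel_context()
    ctx.set_enable_parallel_optimizer(True)       # -> ParallelContext::set_enable_parallel_optimizer
    assert ctx.get_enable_parallel_optimizer() is True
    ctx.reset()                                   # -> ParallelContext::Reset(), flag falls back to False
    assert ctx.get_enable_parallel_optimizer() is False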
@@ -29,8 +29,9 @@ from .optimizer import Optimizer
 _adam_opt = C.MultitypeFuncGraph("adam_opt")


-@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
-def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag):
+@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
+                    "Tensor", "Bool", "Bool")
+def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag, optim_filter):
     """
     Update parameters.

@@ -44,38 +45,44 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad
         m (Tensor): m value of parameters.
         v (Tensor): v value of parameters.
         gradient (Tensor): Gradient of parameters.
+        decay_flag (bool): Applies weight decay or not.
+        optim_filter (bool): Applies parameter update or not.

     Returns:
         Tensor, the new value of v after updating.
     """
-    op_mul = P.Mul()
-    op_square = P.Square()
-    op_sqrt = P.Sqrt()
-    op_cast = P.Cast()
-    op_reshape = P.Reshape()
-    op_shape = P.Shape()
-
-    param_fp32 = op_cast(param, mstype.float32)
-    m_fp32 = op_cast(m, mstype.float32)
-    v_fp32 = op_cast(v, mstype.float32)
-    gradient_fp32 = op_cast(gradient, mstype.float32)
-
-    next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)
-
-    next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
-                                            - beta2, op_square(gradient_fp32))
-
-    update = next_m / (eps + op_sqrt(next_v))
-    if decay_flag:
-        update = op_mul(weight_decay_tensor, param_fp32) + update
-
-    update_with_lr = op_mul(lr, update)
-    next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
-
-    next_v = F.depend(next_v, F.assign(param, op_cast(next_param, F.dtype(param))))
-    next_v = F.depend(next_v, F.assign(m, op_cast(next_m, F.dtype(m))))
-    next_v = F.depend(next_v, F.assign(v, op_cast(next_v, F.dtype(v))))
-    return next_v
+    if optim_filter:
+        op_mul = P.Mul()
+        op_square = P.Square()
+        op_sqrt = P.Sqrt()
+        op_cast = P.Cast()
+        op_reshape = P.Reshape()
+        op_shape = P.Shape()
+
+        param_fp32 = op_cast(param, mstype.float32)
+        m_fp32 = op_cast(m, mstype.float32)
+        v_fp32 = op_cast(v, mstype.float32)
+        gradient_fp32 = op_cast(gradient, mstype.float32)
+
+        next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
+                                                - beta1, gradient_fp32)
+
+        next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
+                                                - beta2, op_square(gradient_fp32))
+
+        update = next_m / (eps + op_sqrt(next_v))
+        if decay_flag:
+            update = op_mul(weight_decay_tensor, param_fp32) + update
+
+        update_with_lr = op_mul(lr, update)
+        next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
+
+        next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
+        next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
+        next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))
+
+        return next_param
+    return gradient


 def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
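Note: the point of the new `optim_filter` argument is that a rank only runs the Adam math for the parameters assigned to it and hands every other gradient back untouched, leaving those slots to be refreshed by the owning rank's broadcast. A toy NumPy sketch of that gating (illustrative only; the helper name is hypothetical and the moment-state writes are omitted):

    import numpy as np

    def gated_adam_step(param, m, v, grad, lr, beta1, beta2, eps,
                        weight_decay, decay_flag, optim_filter):
        """Scalar/NumPy caricature of the gated update in _update_run_op."""
        if optim_filter:
            next_m = beta1 * m + (1.0 - beta1) * grad
            next_v = beta2 * v + (1.0 - beta2) * grad * grad
            update = next_m / (eps + np.sqrt(next_v))
            if decay_flag:
                update = weight_decay * param + update
            return param - lr * update   # this rank owns the parameter: return its new value
        return grad                      # not owned here; the broadcast step overwrites it later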
@@ -300,7 +307,7 @@ class AdamWeightDecay(Optimizer):
         - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

     Outputs:
-        tuple[Parameter], the updated velocity value, the shape is the same as `params`.
+        tuple[bool], all elements are True.

     Examples:
         >>> net = Net()
@@ -328,11 +335,13 @@ class AdamWeightDecay(Optimizer):

     def construct(self, gradients):
         lr = self.get_lr()
-        updated_velocity = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
-                                                    self.weight_decay_tensor),
-                                          self.params, self.moments1, self.moments2, gradients, self.decay_flag)
-
-        return updated_velocity
+        optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
+                                                self.weight_decay_tensor),
+                                      self.params, self.moments1, self.moments2, gradients,
+                                      self.decay_flag, self.optim_filter)
+        if self.use_parallel:
+            optim_result = self.broadcast_params(optim_result)
+        return optim_result


 class AdamWeightDecayDynamicLR(Optimizer):
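Note: with the flag on, `construct` now returns whatever the hyper_map produced after the broadcast pass, so callers see a tuple of per-parameter results rather than the old velocity tuple. A minimal usage sketch mirroring the unit test added at the end of this diff (`MyNet` is a placeholder Cell with trainable parameters):

    import mindspore.nn as nn
    from mindspore import context
    from mindspore.nn import TrainOneStepCell, WithLossCell
    from mindspore.nn.optim import AdamWeightDecay
    from mindspore.parallel._auto_parallel_context import auto_parallel_context

    auto_parallel_context().set_enable_parallel_optimizer(True)
    context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2)

    net = MyNet()                                   # placeholder network
    loss = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = AdamWeightDecay(net.trainable_params(), learning_rate=0.1)
    train_network = TrainOneStepCell(WithLossCell(net, loss), optimizer)
    train_network.set_train()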
@@ -363,7 +372,7 @@ class AdamWeightDecayDynamicLR(Optimizer):
         - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

     Outputs:
-        tuple[Parameter], the updated velocity value, the shape is the same as `params`.
+        tuple[bool], all elements are True.

     Examples:
         >>> net = Net()
@@ -424,12 +433,14 @@ class AdamWeightDecayDynamicLR(Optimizer):
         warmup_lr = self.start_learning_rate * warmup_percent
         is_warmup = self.cast(self.greater(self.warmup_steps, self.global_step), mstype.float32)
         lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr
-        updated_velocity = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
-                                                    self.weight_decay_tensor),
-                                          self.params, self.moments1, self.moments2, gradients, self.decay_flag)
-
+        optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
+                                                self.weight_decay_tensor),
+                                      self.params, self.moments1, self.moments2, gradients,
+                                      self.decay_flag, self.optim_filter)
+        if self.use_parallel:
+            optim_result = self.broadcast_params(optim_result)
         added_global_step = self.global_step + self.one
         F.control_depend(lr, added_global_step)
         self.global_step = added_global_step

-        return updated_velocity
+        return optim_result
@@ -32,11 +32,10 @@ num_one = Tensor(np.ones([1]), mstype.float32)
-

 _lamb_opt = C.MultitypeFuncGraph("lamb_opt")

-@_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
-                    "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
+@_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
+                    "Tensor", "Bool", "Bool")
 def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, param, m, v,
-                   gradient, decay_flag):
+                   gradient, decay_flag, optim_filter):
     """
     Update parameters.

@@ -52,66 +51,66 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
         v (Tensor): v value of parameters.
         gradient (Tensor): Gradient of parameters.
         decay_flag (bool): Specifies whether param update with weight decay.
+        optim_filter(bool): Applies parameter update or not.

     Returns:
         Tensor, the new value of v after updating.
     """
-    op_mul = P.Mul()
-    op_sqrt = P.Sqrt()
-    op_rsqrt = P.Rsqrt()
-    op_square = P.Square()
-    op_cast = P.Cast()
-    op_reshape = P.Reshape()
-    op_shape = P.Shape()
-    op_pow = P.Pow()
-    op_norm = layer.Norm()
-    op_select = P.Select()
-    op_greater = P.Greater()
-    op_fill = P.Fill()
-    op_dtype = P.DType()
-
-    param_fp32 = op_cast(param, mstype.float32)
-    m_fp32 = op_cast(m, mstype.float32)
-    v_fp32 = op_cast(v, mstype.float32)
-    gradient_fp32 = op_cast(gradient, mstype.float32)
-
-    next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(num_one,
-                                                    mstype.float32) - beta1, gradient_fp32)
-
-    next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(num_one,
-                                                    mstype.float32) - beta2, op_square(gradient_fp32))
-
-    next_mm = next_m / (op_cast(num_one, mstype.float32)
-                        - op_pow(beta1, op_cast(global_step + num_one, mstype.float32)))
-    next_vv = next_v / (op_cast(num_one, mstype.float32) -
-                        op_pow(beta2, op_cast(global_step + num_one, mstype.float32)))
-    w_norm = op_norm(param_fp32)
-    g_norm = op_norm(gradient_fp32)
-
-    g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(
-        next_vv + eps)) + weight_decay_tensor * param_fp32)
-    zeros = F.zeros_like(w_norm)
-    ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
-    trust_ratio = op_select(
-        op_greater(w_norm, zeros),
-        op_select(op_greater(g_norm, zeros), w_norm / g_norm_hat, ones),
-        ones)
-    tens = op_fill(op_dtype(trust_ratio), op_shape(trust_ratio), 10.0)
-    trust_ratio = C.clip_by_value(trust_ratio, zeros, tens)
-    update = next_mm / (op_sqrt(next_vv) + eps)
-
-    if decay_flag:
-        update = update + op_mul(weight_decay_tensor, param_fp32)
-
-    update_with_lr = op_mul(op_mul(trust_ratio, lr), update)
-
-    next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
-
-    next_v = F.depend(next_v, F.assign(param, next_param))
-    next_v = F.depend(next_v, F.assign(m, next_m))
-    next_v = F.depend(next_v, F.assign(v, next_v))
-
-    return next_v
+    if optim_filter:
+        op_mul = P.Mul()
+        op_sqrt = P.Sqrt()
+        op_rsqrt = P.Rsqrt()
+        op_square = P.Square()
+        op_cast = P.Cast()
+        op_reshape = P.Reshape()
+        op_shape = P.Shape()
+        op_pow = P.Pow()
+        op_norm = layer.Norm()
+        op_select = P.Select()
+        op_greater = P.Greater()
+        op_fill = P.Fill()
+        op_dtype = P.DType()
+
+        param_fp32 = op_cast(param, mstype.float32)
+        m_fp32 = op_cast(m, mstype.float32)
+        v_fp32 = op_cast(v, mstype.float32)
+        gradient_fp32 = op_cast(gradient, mstype.float32)
+
+        next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta1, gradient_fp32)
+
+        next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta2, op_square(gradient_fp32))
+
+        next_mm = next_m / (op_cast(num_one, mstype.float32)
+                            - op_pow(beta1, op_cast(global_step + num_one, mstype.float32)))
+        next_vv = next_v / (op_cast(num_one, mstype.float32) -
+                            op_pow(beta2, op_cast(global_step + num_one, mstype.float32)))
+        w_norm = op_norm(param_fp32)
+        g_norm = op_norm(gradient_fp32)
+
+        g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay_tensor * param_fp32)
+        zeros = F.zeros_like(w_norm)
+        ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
+        trust_ratio = op_select(
+            op_greater(w_norm, zeros),
+            op_select(op_greater(g_norm, zeros), w_norm / g_norm_hat, ones),
+            ones)
+        tens = op_fill(op_dtype(trust_ratio), op_shape(trust_ratio), 10.0)
+        trust_ratio = C.clip_by_value(trust_ratio, zeros, tens)
+        update = next_mm / (op_sqrt(next_vv) + eps)
+
+        if decay_flag:
+            update = update + op_mul(weight_decay_tensor, param_fp32)
+
+        update_with_lr = op_mul(op_mul(trust_ratio, lr), update)
+
+        next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
+
+        next_param = F.depend(next_param, F.assign(param, next_param))
+        next_param = F.depend(next_param, F.assign(m, next_m))
+        next_param = F.depend(next_param, F.assign(v, next_v))
+
+        return next_param
+    return gradient


 lamb_opt_graph_kernel = C.MultitypeFuncGraph("lamb_opt_graph_kernel")
@@ -238,7 +237,7 @@ class Lamb(Optimizer):
         - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

     Outputs:
-        tuple[Parameter], the updated velocity value, the shape is the same as `params`.
+        tuple[bool], all elements are True.

     Examples:
         >>> net = Net()
@@ -311,18 +310,21 @@ class Lamb(Optimizer):
                               self.warmup_steps, self.global_step), mstype.float32)
         lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr
         if self.enable_graph_kernel:
-            updated_velocity = self.hyper_map(F.partial(lamb_opt_graph_kernel,
-                                                        self.beta1, self.beta2, self.eps, lr,
-                                                        self.weight_decay_tensor, self.global_step),
-                                              self.params, self.moments1, self.moments2, gradients, self.decay_flag)
+            optim_result = self.hyper_map(F.partial(lamb_opt_graph_kernel,
+                                                    self.beta1, self.beta2, self.eps, lr,
+                                                    self.weight_decay_tensor, self.global_step),
+                                          self.params, self.moments1, self.moments2, gradients, self.decay_flag)
         else:
-            updated_velocity = self.hyper_map(F.partial(_lamb_opt,
-                                                        self.beta1, self.beta2, self.eps, lr,
-                                                        self.weight_decay_tensor, self.global_step),
-                                              self.params, self.moments1, self.moments2, gradients, self.decay_flag)
+            optim_result = self.hyper_map(F.partial(_lamb_opt,
+                                                    self.beta1, self.beta2, self.eps, lr,
+                                                    self.weight_decay_tensor, self.global_step),
+                                          self.params, self.moments1, self.moments2, gradients,
+                                          self.decay_flag, self.optim_filter)
+        if self.use_parallel:
+            optim_result = self.broadcast_params(optim_result)

         added_global_step = self.global_step + self.one
         F.control_depend(lr, added_global_step)
         self.global_step = added_global_step

-        return updated_velocity
+        return optim_result
@@ -22,11 +22,14 @@ from mindspore.ops import functional as F, composite as C, operations as P
 from mindspore.nn.cell import Cell
 from mindspore.common.parameter import Parameter, ParameterTuple
 from mindspore.common.initializer import initializer
+from mindspore.common.tensor import Tensor
 import mindspore.common.dtype as mstype
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
-from mindspore.common.tensor import Tensor
 from mindspore import log as logger
+from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode
+from mindspore.parallel._auto_parallel_context import auto_parallel_context
+from mindspore.train.parallel_utils import ParallelMode

 __all__ = ['Optimizer']

@@ -155,6 +158,27 @@ class Optimizer(Cell):
         self.param_length = len(self.parameters)
         self.map_ = C.Map()

+        use_parallel = auto_parallel_context().get_enable_parallel_optimizer()
+        self.use_parallel = use_parallel
+        if use_parallel:
+            if self.cls_name not in ["Lamb", "AdamWeightDecayDynamicLR", "AdamWeightDecay"]:
+                raise RuntimeError("Optimizer segmentation does not support optimizer {}".format(self.cls_name))
+            if _get_parallel_mode() not in [ParallelMode.HYBRID_PARALLEL, ParallelMode.DATA_PARALLEL,
+                                            ParallelMode.AUTO_PARALLEL]:
+                raise RuntimeError("Optimizer segmentation does not support parallel mode {}".format
+                                   (_get_parallel_mode()))
+            self.dev_num = _get_device_num()
+            if self.dev_num > self.param_length:
+                raise RuntimeError("Optimizer segmentation can not be applied when the number of parameters {} is"
+                                   " less than the number of devices {}".format(self.param_length, self.dev_num))
+            self.param_rank = self._get_parameter_group_id()
+            self.optim_filter = tuple(map(lambda x: x == _get_global_rank(), self.param_rank))
+            self.param_names = []
+            for param in self.parameters:
+                self.param_names.append(param.name)
+        else:
+            self.optim_filter = (True,) * self.param_length
+
     def decay_weight(self, gradients):
         """
         Weight decay.
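Note: the shape of `optim_filter` is easiest to see with concrete numbers. Parameters are dealt out round-robin over `dev_num` ranks by `_get_parameter_group_id()` (next hunk), and each process keeps True only for the slots whose group id equals its own global rank. A plain-Python trace under assumed values (6 parameters, 4 devices, this process is global rank 1):

    param_length, dev_num, global_rank = 6, 4, 1      # assumed example values

    # Same round-robin assignment as Optimizer._get_parameter_group_id()
    param_rank = tuple(i % dev_num for i in range(param_length))
    # -> (0, 1, 2, 3, 0, 1)

    optim_filter = tuple(r == global_rank for r in param_rank)
    # -> (False, True, False, False, False, True): rank 1 updates parameters 1 and 5 only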
@@ -401,6 +425,51 @@ class Optimizer(Cell):
             lr = self.learning_rate
         return lr

+    def _get_parameter_group_id(self):
+        """
+        Get the parameter partition group id, which is less than the number of devices.
+
+        Returns:
+            tuple, the group id tuple of parameters.
+        """
+        rank_list = ()
+        count = 0
+        for _ in range(self.param_length):
+            rank_list = rank_list + (count,)
+            count = count + 1
+            if count == self.dev_num:
+                count = 0
+        return rank_list
+
+    def broadcast_params(self, optim_result):
+        """
+        Apply Broadcast operations in the sequential order of parameter groups.
+
+        Returns:
+            bool, the status flag.
+        """
+        param_group = []
+        key_group = []
+        for _ in range(self.dev_num):
+            param_group.append(F.make_tuple())
+            key_group.append(F.make_tuple())
+        for i in range(self.param_length):
+            param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (optim_result[i],)
+            key = P.MakeRefKey(self.param_names[i])()
+            key_group[self.param_rank[i]] = key_group[self.param_rank[i]] + (key,)
+        new_param_group = []
+        for root in range(self.dev_num):
+            ops = P.Broadcast(root)
+            next_params = ops(param_group[root])
+            new_param_group.append(next_params)
+            for i in range(F.tuple_len(next_params)):
+                F.assign(key_group[root][i], next_params[i])
+        status = True
+        for i in range(self.dev_num - 1):
+            status = F.control_depend(new_param_group[i][0], new_param_group[i+1])
+
+        return status
+
     def construct(self, *hyper_params):
         raise NotImplementedError

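Note: `broadcast_params` first buckets the per-parameter results by owning rank, then issues one `P.Broadcast(root)` per bucket and assigns the received tensors back through their RefKeys, so every device ends up holding every updated slice. The grouping step, sketched with plain Python containers (hypothetical helper; the real method builds graph tuples, not lists of strings):

    def group_by_rank(optim_result, param_rank, dev_num):
        """Bucket each rank's results so a single Broadcast(root) can ship each bucket."""
        groups = [[] for _ in range(dev_num)]
        for idx, value in enumerate(optim_result):
            groups[param_rank[idx]].append(value)
        return groups

    # With param_rank = (0, 1, 0, 1) on 2 devices:
    # group_by_rank(["p0", "p1", "p2", "p3"], (0, 1, 0, 1), 2) -> [["p0", "p2"], ["p1", "p3"]]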
@@ -85,6 +85,22 @@ def _list_setitem_with_List(data, number_index, value):
     return F.list_setitem(data, number_index, value)


+@setitem.register("List", "Number", "Tuple")
+def _list_setitem_with_Tuple(data, number_index, value):
+    """
+    Assigns value to list.
+
+    Inputs:
+        data (list): Data of type lis.
+        number_index (Number): Index of data.
+        value (list): Value given.
+
+    Outputs:
+        list, type is same as the element type of data.
+    """
+    return F.list_setitem(data, number_index, value)
+
+
 @setitem.register("Dictionary", "String", "Tensor")
 def _dict_setitem_with_tensor(data, key, value):
     """
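Note: this overload appears to exist for `broadcast_params` above, which assigns tuples into list slots inside `construct` (`param_group[rank] = param_group[rank] + (x,)`, where `param_group` starts as a list of graph tuples); without a `("List", "Number", "Tuple")` registration the graph compiler has no setitem rule for that pattern. In plain Python terms it simply permits:

    param_group = [(), ()]
    param_group[0] = param_group[0] + ("p0",)   # list[int] = tuple, the shape used in broadcast_params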
@@ -98,6 +98,7 @@ class AllReduce(PrimitiveWithInfer):
         self.op = op
         self.add_prim_attr('group', _get_group(group))
         self.add_prim_attr('fusion', 0)
+        self.add_prim_attr('index', 0)

     def vm_impl(self, x):
         """Implement by vm mode."""
@@ -59,8 +59,7 @@ class Assign(PrimitiveWithInfer):
         return variable

     def infer_dtype(self, variable, value):
-        args = {"variable": variable, "value": value}
-        validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
+        # Add a type validation later when we don't have to assign a value to RefKey.
         return variable


@@ -400,6 +400,23 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_global_rank_is_set()

+    def set_enable_parallel_optimizer(self, enable_parallel_optimizer):
+        """
+        Set enable/disable parallel optimizer.
+
+        Args:
+            set_enable_parallel_optimizer (bool): Enable/disable parallel optimizer.
+        """
+        self.check_context_handle()
+        if not isinstance(enable_parallel_optimizer, bool):
+            raise TypeError('enable_parallel_optimizer is invalid type')
+        self._context_handle.set_enable_parallel_optimizer(enable_parallel_optimizer)
+
+    def get_enable_parallel_optimizer(self):
+        """Get parallel optimizer flag."""
+        self.check_context_handle()
+        return self._context_handle.get_enable_parallel_optimizer()
+
     def reset(self):
         """Reset all settings."""
         self.check_context_handle()
@@ -433,7 +450,8 @@ _set_auto_parallel_context_func_map = {
     "parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
     "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
     "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,
-    "full_batch": auto_parallel_context().set_full_batch}
+    "full_batch": auto_parallel_context().set_full_batch,
+    "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer}


 _get_auto_parallel_context_func_map = {
@@ -447,13 +465,15 @@ _get_auto_parallel_context_func_map = {
     "parameter_broadcast": auto_parallel_context().get_parameter_broadcast,
     "strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file,
     "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file,
-    "full_batch": auto_parallel_context().get_full_batch}
+    "full_batch": auto_parallel_context().get_full_batch,
+    "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer}
+


 @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool,
                  loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
                  parameter_broadcast=bool, strategy_ckpt_load_file=str,
-                 strategy_ckpt_save_file=str, full_batch=bool)
+                 strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool)
 def _set_auto_parallel_context(**kwargs):
     """
     Set auto parallel context.
@@ -493,6 +513,7 @@ def _set_auto_parallel_context(**kwargs):
         strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
         strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
         full_batch (bool): Whether to load the whole batch on each device. Default: False.
+        enable_parallel_optimizer (bool): Enable using optimizer segmentation or noe. Default: False.

     Raises:
         ValueError: If input key is not attribute in auto parallel context.
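Note: with the two map entries and the widened `args_type_check` above, the flag is also reachable as a keyword of the private context helpers, not only through `auto_parallel_context()` directly. A hedged sketch, assuming the module's usual `_get_auto_parallel_context(attr_key)` accessor:

    from mindspore.parallel._auto_parallel_context import (
        _set_auto_parallel_context, _get_auto_parallel_context, _reset_auto_parallel_context)

    _set_auto_parallel_context(enable_parallel_optimizer=True)
    assert _get_auto_parallel_context("enable_parallel_optimizer") is True
    _reset_auto_parallel_context()    # restores the False default documented in the next hunk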
@@ -535,5 +556,6 @@ def _reset_auto_parallel_context():
     - parameter_broadcast: False.
     - strategy_ckpt_load_file: ""
     - strategy_ckpt_save_file: ""
+    - enable_parallel_optimizer: False
     """
     auto_parallel_context().reset()
@@ -327,6 +327,9 @@ TEST_F(TestStepParallel, CreatOpInstance) {
     } else if (name == "instance_name") {
       parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
       ASSERT_EQ(converted_ret->ToString(), "test");
+    } else if (name == "index") {
+      parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
+      ASSERT_EQ(converted_ret->ToString(), "0");
     } else {
       MS_LOG(EXCEPTION) << "Test failed";
     }
@@ -16,7 +16,6 @@
 test assign sub
 """
 import numpy as np
-import pytest

 import mindspore.context as context
 import mindspore.nn as nn
@@ -36,27 +35,6 @@ class AssignW(nn.Cell):
         return x


-class Net(nn.Cell):
-    def __init__(self):
-        super(Net, self).__init__()
-        self.b = Parameter(initializer('ones', [5]), name='b')
-        self.assign = AssignW()
-
-    def construct(self, value):
-        return self.assign(self.b, value)
-
-
-def test_assign_through_cell():
-    context.set_context(mode=context.GRAPH_MODE)
-    net = Net()
-    net.to_float(ms.float16)
-    net.add_flags_recursive(fp16=False)
-    input_data = Tensor(np.ones([5]).astype(np.float32))
-    net(input_data)
-    with pytest.raises(TypeError):
-        net(None)
-
-
 class AssignOp(nn.Cell):
     def __init__(self):
         super(AssignOp, self).__init__()
@@ -0,0 +1,114 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" test adam """
+import numpy as np
+import pytest
+
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.common.api import _executor
+from mindspore.nn import TrainOneStepCell, WithLossCell
+from mindspore.nn.optim import Adam, AdamWeightDecay, AdamWeightDecayDynamicLR, Lamb
+from mindspore.ops import operations as P
+from mindspore.parallel._auto_parallel_context import auto_parallel_context
+from mindspore import context
+
+
+class Net(nn.Cell):
+    """Net definition"""
+
+    def __init__(self):
+        super(Net, self).__init__()
+        self.fc1 = nn.Dense(128, 768, activation='relu')
+        self.fc2 = nn.Dense(128, 768, activation='relu')
+        self.fc3 = nn.Dense(128, 768, activation='relu')
+        self.fc4 = nn.Dense(768, 768, activation='relu')
+        self.relu4 = nn.ReLU()
+        self.relu5 = nn.ReLU()
+        self.transpose = P.Transpose()
+        self.matmul1 = P.MatMul()
+        self.matmul2 = P.MatMul()
+
+    def construct(self, x):
+        q = self.fc1(x)
+        k = self.fc2(x)
+        v = self.fc3(x)
+        k = self.transpose(k, (1, 0))
+        c = self.relu4(self.matmul1(q, k))
+        s = self.relu5(self.matmul2(c, v))
+        s = self.fc4(s)
+        return s
+
+
+def test_AdamWeightDecayDynamicLR():
+    """ test_AdamWeightDecayDynamicLR """
+    auto_parallel_context().set_enable_parallel_optimizer(True)
+    context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2)
+    inputs = Tensor(np.ones([32, 128]).astype(np.float32))
+    label = Tensor(np.zeros([32, 768]).astype(np.float32))
+    net = Net()
+    net.set_train()
+    loss = nn.SoftmaxCrossEntropyWithLogits()
+    optimizer = AdamWeightDecayDynamicLR(net.trainable_params(), decay_steps=20, learning_rate=0.1)
+
+    net_with_loss = WithLossCell(net, loss)
+    train_network = TrainOneStepCell(net_with_loss, optimizer)
+    _executor.compile(train_network, inputs, label)
+
+
+def test_AdamWeightDecay():
+    """ test_AdamWeightDecayDynamicLR """
+    auto_parallel_context().set_enable_parallel_optimizer(True)
+    context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2)
+    inputs = Tensor(np.ones([32, 128]).astype(np.float32))
+    label = Tensor(np.zeros([32, 768]).astype(np.float32))
+    net = Net()
+    net.set_train()
+    loss = nn.SoftmaxCrossEntropyWithLogits()
+    optimizer = AdamWeightDecay(net.trainable_params(), learning_rate=0.1)
+
+    net_with_loss = WithLossCell(net, loss)
+    train_network = TrainOneStepCell(net_with_loss, optimizer)
+    _executor.compile(train_network, inputs, label)
+
+
+def test_lamb_compile():
+    """ test_Lamb_compile """
+    auto_parallel_context().set_enable_parallel_optimizer(True)
+    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=2)
+    inputs = Tensor(np.ones([32, 128]).astype(np.float32))
+    label = Tensor(np.zeros([32, 768]).astype(np.float32))
+    net = Net()
+    net.set_train()
+    loss = nn.SoftmaxCrossEntropyWithLogits()
+    optimizer = Lamb(net.trainable_params(), decay_steps=10)
+
+    net_with_loss = WithLossCell(net, loss)
+    train_network = TrainOneStepCell(net_with_loss, optimizer)
+    _executor.compile(train_network, inputs, label)
+
+
+def test_edge_case():
+    """ test_edge_case """
+    auto_parallel_context().set_enable_parallel_optimizer(True)
+    net = Net()
+    with pytest.raises(RuntimeError):
+        context.set_auto_parallel_context(parallel_mode="stand_alone")
+        Lamb(net.trainable_params(), decay_steps=10)
+    with pytest.raises(RuntimeError):
+        Adam(net.trainable_params(), learning_rate=0.1)
+    with pytest.raises(RuntimeError):
+        context.set_auto_parallel_context(device_num=16)
+        Lamb(net.trainable_params(), decay_steps=10)
@@ -81,6 +81,10 @@ def test_set_auto_parallel_context():
     with pytest.raises(ValueError):
         set_algo_parameters(tensor_slice_align_size=1025)

+    auto_parallel_context().set_enable_parallel_optimizer(True)
+    assert auto_parallel_context().get_enable_parallel_optimizer() is True
+    assert not auto_parallel_context().get_all_reduce_fusion_split_indices()
+

 def test_reset_auto_parallel_context():
     context.reset_auto_parallel_context()