diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_weight_decay_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_weight_decay_cpu_kernel.cc
index 33468d46657..ae3182d97f7 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_weight_decay_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_weight_decay_cpu_kernel.cc
@@ -59,11 +59,11 @@ void AdamWeightDecayCPUKernel::LaunchFusedAdam(const std::vector<AddressPtr> &in
   auto beta1 = reinterpret_cast<float *>(inputs[4]->addr)[0];
   auto beta2 = reinterpret_cast<float *>(inputs[5]->addr)[0];
   auto epsilon = reinterpret_cast<float *>(inputs[6]->addr)[0];
-  auto decay = reinterpret_cast<float *>(inputs[7]->addr);
+  auto decay = reinterpret_cast<float *>(inputs[7]->addr)[0];
   auto gradient16 = reinterpret_cast<float16 *>(inputs[8]->addr);
-  float beta1_minus = 1 - beta1;
-  float beta2_minus = 1 - beta2;
+  const auto beta1_minus = 1 - beta1;
+  const auto beta2_minus = 1 - beta2;
   // multithreading
   size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1;
   std::function<void(size_t, size_t)> task;
@@ -77,7 +77,7 @@ void AdamWeightDecayCPUKernel::LaunchFusedAdam(const std::vector<AddressPtr> &in
       m[i] += (temp - m[i]) * beta1_minus;
       v[i] += (temp * temp - v[i]) * beta2_minus;
       T update = m[i] / (std::sqrt(v[i]) + epsilon);
-      update += *decay * var[i];
+      update += decay * var[i];
       var[i] -= lr * update;
     }
   };
@@ -94,10 +94,10 @@ void AdamWeightDecayCPUKernel::LaunchAdamWeightDecay(const std::vector<AddressPtr> &in
   auto beta1 = reinterpret_cast<float *>(inputs[4]->addr)[0];
   auto beta2 = reinterpret_cast<float *>(inputs[5]->addr)[0];
   auto epsilon = reinterpret_cast<float *>(inputs[6]->addr)[0];
-  auto decay = reinterpret_cast<float *>(inputs[7]->addr);
+  auto decay = reinterpret_cast<float *>(inputs[7]->addr)[0];
   auto gradient = reinterpret_cast<float *>(inputs[8]->addr);
-  auto beta1_minus = 1 - beta1;
-  auto beta2_minus = 1 - beta2;
+  const auto beta1_minus = 1 - beta1;
+  const auto beta2_minus = 1 - beta2;
   // multithreading
   size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1;
@@ -110,7 +110,7 @@ void AdamWeightDecayCPUKernel::LaunchAdamWeightDecay(const std::vector<AddressPtr> &in
 #endif
 #ifdef ENABLE_AVX512
+const size_t kUnrollSize = 4;
 struct AVX_Data {
   __m512 data;
 };
-static inline void LoadStep4(struct AVX_Data *inp0, const float *inp1) {
+static inline int LoadStep4(struct AVX_Data *inp0, const float *inp1, const size_t arrLen) {
+  if (arrLen != kUnrollSize) {
+    return NNACL_ERR;
+  }
+  if (inp0 == NULL || inp1 == NULL) {
+    return NNACL_NULL_PTR;
+  }
   inp0[0].data = _mm512_loadu_ps(inp1);
   inp0[1].data = _mm512_loadu_ps(inp1 + C16NUM);
   inp0[2].data = _mm512_loadu_ps(inp1 + C16NUM * 2);
   inp0[3].data = _mm512_loadu_ps(inp1 + C16NUM * 3);
+  return NNACL_OK;
 }
-static inline void StoreStep4(float *inp0, struct AVX_Data *inp1) {
+static inline int StoreStep4(float *inp0, struct AVX_Data *inp1, const size_t arrLen) {
+  if (arrLen != kUnrollSize) {
+    return NNACL_ERR;
+  }
+  if (inp0 == NULL || inp1 == NULL) {
+    return NNACL_NULL_PTR;
+  }
   _mm512_storeu_ps(inp0, inp1[0].data);
   _mm512_storeu_ps(inp0 + C16NUM, inp1[1].data);
   _mm512_storeu_ps(inp0 + C16NUM * 2, inp1[2].data);
   _mm512_storeu_ps(inp0 + C16NUM * 3, inp1[3].data);
+  return NNACL_OK;
 }
 #endif
 #ifdef __cplusplus
@@ -52,9 +67,9 @@ int AdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2,
              size_t start, size_t end, bool use_nesterov);
 int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
                   const float *gradient, size_t start, size_t end, bool use_nesterov);
-int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float *decay,
+int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
                         const float *gradient, size_t start, size_t end);
-int FusedAdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float *decay,
+int FusedAdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
                   const int16_t *gradient16, size_t start, size_t end);
 #ifdef __cplusplus
 }
diff --git a/model_zoo/official/nlp/gpt/src/adam.py b/model_zoo/official/nlp/pangu_alpha/src/adam.py
similarity index 86%
rename from model_zoo/official/nlp/gpt/src/adam.py
rename to model_zoo/official/nlp/pangu_alpha/src/adam.py
index 3403ea95ec6..270e05b6af9 100644
--- a/model_zoo/official/nlp/gpt/src/adam.py
+++ b/model_zoo/official/nlp/pangu_alpha/src/adam.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""AdamWeightDecay, a customized Adam for pangu1. Input: gradient."""
+"""AdamWeightDecay, a customized Adam for pangu alpha. Input: gradient."""
 import numpy as np
 
 from mindspore.common import dtype as mstype
@@ -35,15 +35,14 @@ def _update_run_kernel(opt, beta1, beta2, eps, lr, weight_decay, param, m, v, gr
     """
     Update parameters by AdamWeightDecay op.
""" + success = True if optim_filter: - op_cast = P.Cast() - gradient_fp32 = op_cast(gradient, mstype.float32) if decay_flags: - next_param = opt(param, m, v, lr, beta1, beta2, eps, F.cast(weight_decay, mstype.float32), gradient_fp32) + next_param = opt(param, m, v, lr, beta1, beta2, eps, weight_decay, gradient) else: - next_param = opt(param, m, v, lr, beta1, beta2, eps, F.cast(0.0, mstype.float32), gradient_fp32) - return next_param - return gradient + next_param = opt(param, m, v, lr, beta1, beta2, eps, 0.0, gradient) + return F.depend(success, next_param) + return success def _check_param_value(beta1, beta2, eps, prim_name): @@ -145,18 +144,17 @@ class AdamWeightDecayOp(Optimizer): lr = self.get_lr() if self.is_group: if self.is_group_lr: - optim_result = self.map_(F.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps), - lr, self.weight_decay, self.parameters, self.moments1, self.moments2, - gradients, self.decay_flags, self.optim_filter) + optim_result = self.map_reverse(F.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps), + lr, self.weight_decay, self.parameters, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) else: - optim_result = self.map_(F.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps, lr), - self.weight_decay, self.parameters, self.moments1, self.moments2, - gradients, self.decay_flags, self.optim_filter) + optim_result = self.map_reverse(F.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps, lr), + self.weight_decay, self.parameters, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) else: - optim_result = self.map_(F.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps, lr, - self.weight_decay), - self.parameters, self.moments1, self.moments2, gradients, self.decay_flags, - self.optim_filter) + optim_result = self.map_reverse(F.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps, lr, + self.weight_decay), self.parameters, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) if self.use_parallel: self.broadcast_params(optim_result) return optim_result diff --git a/model_zoo/official/nlp/pangu_alpha/src/utils.py b/model_zoo/official/nlp/pangu_alpha/src/utils.py index 29a26557fc2..0289a9e4e27 100644 --- a/model_zoo/official/nlp/pangu_alpha/src/utils.py +++ b/model_zoo/official/nlp/pangu_alpha/src/utils.py @@ -326,6 +326,10 @@ def add_training_params(opt): default="adam", choices=["adam", "lamb"], help="select which optimizer to be used, default adam") + opt.add_argument("--opt_offload", + type=int, + default=0, + help="Enable optimizer status offload to host CPU, default is 0") opt.add_argument("--eod_id", type=int, default=6, diff --git a/model_zoo/official/nlp/pangu_alpha/train.py b/model_zoo/official/nlp/pangu_alpha/train.py index a12cbf73ba1..a1e06a8e266 100644 --- a/model_zoo/official/nlp/pangu_alpha/train.py +++ b/model_zoo/official/nlp/pangu_alpha/train.py @@ -30,6 +30,7 @@ import mindspore.common.dtype as mstype from mindspore.parallel import set_algo_parameters from mindspore.parallel._cost_model_context import _set_multi_subgraphs from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell +from src.adam import AdamWeightDecayOp from src.dataset import create_dataset from src.pangu_alpha import PanguAlpha, PanguAlphaWithLoss, CrossEntropyLoss from src.pangu_alpha_wrapcell import PanguAlphaTrainOneStepWithLossScaleCell, PanguAlphaTrainPipelineWithLossScaleCell @@ -76,13 +77,31 @@ project_root = 
os.path.abspath( print('project_root:', project_root) +def set_weight_decay(params): + """ + Set weight decay coefficient, zero for bias and layernorm, 1e-1 for rest + """ + decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower() + decay_params = list(filter(decay_filter, params)) + other_params = list(filter(lambda x: not decay_filter(x), params)) + group_params = [{ + 'params': decay_params, + 'weight_decay': 1e-1 + }, { + 'params': other_params, + 'weight_decay': 0.0 + }, { + 'order_params': params + }] + return group_params + + def run_train(args_opt): r""" The main training process. """ # Set execution mode - context.set_context(mode=context.GRAPH_MODE, - device_target=args_opt.device_target) + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) context.set_context(variable_memory_max_size="30GB") # Set parallel context if args_opt.distribute == "true": @@ -149,22 +168,12 @@ def run_train(args_opt): warmup_steps=args_opt.warmup_step, decay_steps=200000) - # Set weight decay coefficient, zero for bias and layernorm, 1e-1 for rest - decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower() params = pangu_alpha.trainable_params() - decay_params = list(filter(decay_filter, params)) - other_params = list(filter(lambda x: not decay_filter(x), params)) - group_params = [{ - 'params': decay_params, - 'weight_decay': 1e-1 - }, { - 'params': other_params, - 'weight_decay': 0.0 - }, { - 'order_params': params - }] + group_params = set_weight_decay(params) if args_opt.optimizer == "lamb": optimizer = nn.Lamb(group_params, learning_rate=lr) + elif args_opt.opt_offload: + optimizer = AdamWeightDecayOp(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95) else: optimizer = FP32StateAdamWeightDecay(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95) # Initial scaling sens @@ -201,6 +210,7 @@ def run_train(args_opt): print("Dataset size: {}, actual_epoch_num: {}".format(ds.get_dataset_size(), actual_epoch_num), flush=True) model.train(actual_epoch_num, ds, callbacks=callback, sink_size=callback_size, dataset_sink_mode=True) + def run_train_pipeline(args_opt): r""" The main training process in pipeline. 
@@ -263,20 +273,11 @@ def run_train_pipeline(args_opt):
     lr = LearningRate(learning_rate=args_opt.start_lr, end_learning_rate=args_opt.end_lr,
                       warmup_steps=args_opt.warmup_step, decay_steps=args_opt.decay_steps)
     params = pangu_alpha.infer_param_pipeline_stage()
-    decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower()
-    decay_params = list(filter(decay_filter, params))
-    other_params = list(filter(lambda x: not decay_filter(x), params))
-    group_params = [{
-        'params': decay_params,
-        'weight_decay': 1e-1
-    }, {
-        'params': other_params,
-        'weight_decay': 0.0
-    }, {
-        'order_params': params
-    }]
+    group_params = set_weight_decay(params)
     if args_opt.optimizer == "lamb":
         optimizer = nn.Lamb(group_params, learning_rate=lr)
+    elif args_opt.opt_offload:
+        optimizer = AdamWeightDecayOp(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95)
     else:
         optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, beta1=0.9, beta2=0.95, eps=1e-8)
@@ -297,6 +298,7 @@ def run_train_pipeline(args_opt):
     model.train(actual_epoch_num, ds, callbacks=callback, sink_size=callback_size, dataset_sink_mode=True)
+

 if __name__ == "__main__":
     opt = get_args()
     set_parse(opt)
diff --git a/tests/st/ops/cpu/test_adam_weight_decay_op.py b/tests/st/ops/cpu/test_adam_weight_decay_op.py
index 2da5f6adbe2..d70e94ae7bf 100644
--- a/tests/st/ops/cpu/test_adam_weight_decay_op.py
+++ b/tests/st/ops/cpu/test_adam_weight_decay_op.py
@@ -22,7 +22,7 @@ from mindspore import Tensor
 from mindspore.nn import Dense
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.ops import operations as P
-from model_zoo.official.nlp.gpt.src.adam import AdamWeightDecayOp
+from model_zoo.official.nlp.pangu_alpha.src.adam import AdamWeightDecayOp

 context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
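
For reference, the per-element update computed by the fused CPU kernel is unchanged by this patch; only the weight-decay term is now passed as a plain scalar (float decay) instead of a pointer. A minimal NumPy sketch of one step, mirroring the kernel loop above (function and variable names here are illustrative, not taken from the source):

import numpy as np

def adam_weight_decay_step(var, m, v, grad, lr, beta1, beta2, eps, decay):
    """One non-bias-corrected AdamWeightDecay step, as in the fused kernel loop."""
    m += (grad - m) * (1 - beta1)          # m[i] += (temp - m[i]) * beta1_minus
    v += (grad * grad - v) * (1 - beta2)   # v[i] += (temp * temp - v[i]) * beta2_minus
    update = m / (np.sqrt(v) + eps)        # update = m[i] / (sqrt(v[i]) + epsilon)
    update += decay * var                  # decay is the scalar weight_decay of the group
    var -= lr * update
    return var, m, v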
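Both training entry points now share the grouping convention in set_weight_decay, and each group's scalar weight_decay is what ultimately reaches the kernel as float decay. A small self-contained sketch of that group structure, using toy parameters; the stock nn.AdamWeightDecay stands in here for the offloaded AdamWeightDecayOp, which accepts the same groups when --opt_offload is set:

import numpy as np
import mindspore.nn as nn
from mindspore import Parameter, Tensor

# Toy parameters: the weight is decayed, the bias is exempt, as in set_weight_decay().
weight = Parameter(Tensor(np.ones((2, 2), np.float32)), name="fc.weight")
bias = Parameter(Tensor(np.zeros((2,), np.float32)), name="fc.bias")

group_params = [
    {'params': [weight], 'weight_decay': 1e-1},  # reaches the kernel as the scalar decay
    {'params': [bias], 'weight_decay': 0.0},     # bias/layernorm groups get zero decay
    {'order_params': [weight, bias]},
]
optimizer = nn.AdamWeightDecay(group_params, learning_rate=1e-3, eps=1e-8, beta1=0.9, beta2=0.95)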