!20111 add adam offload for pangu & fix AdamWeightDecay nnacl

Merge pull request !20111 from zhaosida/zsd_adam_simd
i-robot 2021-07-15 08:53:37 +00:00 committed by Gitee
commit eaf9588ac9
8 changed files with 110 additions and 81 deletions


@@ -59,11 +59,11 @@ void AdamWeightDecayCPUKernel::LaunchFusedAdam(const std::vector<AddressPtr> &in
auto beta1 = reinterpret_cast<T *>(inputs[4]->addr)[0];
auto beta2 = reinterpret_cast<T *>(inputs[5]->addr)[0];
auto epsilon = reinterpret_cast<T *>(inputs[6]->addr)[0];
auto decay = reinterpret_cast<T *>(inputs[7]->addr);
auto decay = reinterpret_cast<T *>(inputs[7]->addr)[0];
auto gradient16 = reinterpret_cast<S *>(inputs[8]->addr);
const auto beta1_minus = 1 - beta1;
const auto beta2_minus = 1 - beta2;
float beta1_minus = 1 - beta1;
float beta2_minus = 1 - beta2;
// multithreading
size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1;
std::function<void(size_t, size_t)> task;
@@ -77,7 +77,7 @@ void AdamWeightDecayCPUKernel::LaunchFusedAdam(const std::vector<AddressPtr> &in
m[i] += (temp - m[i]) * beta1_minus;
v[i] += (temp * temp - v[i]) * beta2_minus;
T update = m[i] / (std::sqrt(v[i]) + epsilon);
update += *decay * var[i];
update += decay * var[i];
var[i] -= lr * update;
}
};
@@ -94,10 +94,10 @@ void AdamWeightDecayCPUKernel::LaunchAdamWeightDecay(const std::vector<AddressPt
auto beta1 = reinterpret_cast<T *>(inputs[4]->addr)[0];
auto beta2 = reinterpret_cast<T *>(inputs[5]->addr)[0];
auto epsilon = reinterpret_cast<T *>(inputs[6]->addr)[0];
auto decay = reinterpret_cast<T *>(inputs[7]->addr);
auto decay = reinterpret_cast<T *>(inputs[7]->addr)[0];
auto gradient = reinterpret_cast<T *>(inputs[8]->addr);
auto beta1_minus = 1 - beta1;
auto beta2_minus = 1 - beta2;
const auto beta1_minus = 1 - beta1;
const auto beta2_minus = 1 - beta2;
// multithreading
size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1;
@@ -110,7 +110,7 @@ void AdamWeightDecayCPUKernel::LaunchAdamWeightDecay(const std::vector<AddressPt
m[i] += (gradient[i] - m[i]) * beta1_minus;
v[i] += (gradient[i] * gradient[i] - v[i]) * beta2_minus;
T update = m[i] / (std::sqrt(v[i]) + epsilon);
update += decay[0] * var[i];
update += decay * var[i];
var[i] -= lr * update;
}
};
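For reference, LaunchFusedAdam and LaunchAdamWeightDecay both apply the same element-wise AdamWeightDecay step, now with `decay` read as a scalar instead of through a pointer. A minimal scalar sketch in C (an illustration, not the kernel code; `AdamWeightDecayStep` is a hypothetical name):

```c
#include <math.h>
#include <stddef.h>

/* Scalar sketch of the per-element update performed by the kernels above.
 * `decay` is passed by value, matching the inputs[7] scalar-read change. */
static void AdamWeightDecayStep(float *var, float *m, float *v, const float *grad,
                                float lr, float beta1, float beta2, float epsilon,
                                float decay, size_t start, size_t end) {
  const float beta1_minus = 1.0f - beta1;
  const float beta2_minus = 1.0f - beta2;
  for (size_t i = start; i < end; ++i) {
    m[i] += (grad[i] - m[i]) * beta1_minus;            /* m <- beta1*m + (1-beta1)*g   */
    v[i] += (grad[i] * grad[i] - v[i]) * beta2_minus;  /* v <- beta2*v + (1-beta2)*g*g */
    float update = m[i] / (sqrtf(v[i]) + epsilon) + decay * var[i];
    var[i] -= lr * update;
  }
}
```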


@@ -16,6 +16,9 @@ endif()
if("${X86_64_SIMD}" STREQUAL "avx")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.1 -mavx -mavx2 -mfma")
endif()
if("${X86_64_SIMD}" STREQUAL "avx512")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.1 -mavx -mavx2 -mfma")
endif()
if("${X86_64_SIMD}" STREQUAL "sse")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.1")
endif()
@@ -52,6 +55,13 @@ if("${X86_64_SIMD}" STREQUAL "avx")
set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C)
endif()
if("${X86_64_SIMD}" STREQUAL "avx512")
file(GLOB ASSEMBLY_SRC ${NNACL_DIR}/intrinsics/sse/*.c
${NNACL_DIR}/intrinsics/avx/*.c
${NNACL_DIR}/assembly/avx/*.S)
set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C)
endif()
if(APPLE)
set_source_files_properties(${ASSEMBLY_SRC} PROPERTIES COMPILE_FLAGS "-x assembler-with-cpp")
endif()
@@ -74,7 +84,7 @@ if(ENABLE_CPU)
elseif("${X86_64_SIMD}" STREQUAL "avx")
target_compile_definitions(nnacl_mid PRIVATE ENABLE_SSE ENABLE_AVX)
elseif("${X86_64_SIMD}" STREQUAL "avx512")
target_compile_definitions(nnacl_mid PRIVATE ENABLE_AVX512)
target_compile_definitions(nnacl_mid PRIVATE ENABLE_SSE ENABLE_AVX ENABLE_AVX512)
target_compile_options(nnacl_mid PRIVATE -mavx512f)
endif()
target_compile_options(nnacl_mid PRIVATE -fPIC)
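The avx512 branch now reuses the SSE/AVX compiler flags, pulls in the SSE/AVX sources, and defines ENABLE_SSE, ENABLE_AVX and ENABLE_AVX512 together with -mavx512f. A small sketch (not nnacl code; `ScaleInPlace` is a hypothetical helper) of the pattern such a definition gates, with a guarded vector path and a scalar tail that also serves non-AVX-512 builds:

```c
#include <stddef.h>
#ifdef ENABLE_AVX512
#include <immintrin.h>
#endif

/* Hypothetical helper: scales a buffer in place, vectorized only when the
 * build adds ENABLE_AVX512 and -mavx512f as in the CMake branch above. */
static void ScaleInPlace(float *data, float alpha, size_t n) {
  size_t i = 0;
#ifdef ENABLE_AVX512
  __m512 alpha_r = _mm512_set1_ps(alpha);
  for (; i + 16 <= n; i += 16) {
    _mm512_storeu_ps(data + i, _mm512_mul_ps(_mm512_loadu_ps(data + i), alpha_r));
  }
#endif
  for (; i < n; ++i) {  /* scalar tail, and the only path without AVX-512 */
    data[i] *= alpha;
  }
}
```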


@@ -152,12 +152,12 @@ int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float
return NNACL_OK;
}
int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float *decay,
int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
const float *gradient, size_t start, size_t end) {
size_t c1 = start;
#ifdef ENABLE_AVX512
float beta1_minus = 1 - beta1;
float beta2_minus = 1 - beta2;
const float beta1_minus = 1 - beta1;
const float beta2_minus = 1 - beta2;
struct AVX_Data beta1_r, beta2_r, beta1_minus_r, beta2_minus_r, lr_neg_r, epsilon_r, decay_r;
beta1_r.data = _mm512_set1_ps(beta1);
beta2_r.data = _mm512_set1_ps(beta2);
@@ -165,7 +165,7 @@ int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, f
beta2_minus_r.data = _mm512_set1_ps(beta2_minus);
lr_neg_r.data = _mm512_set1_ps(-lr);
epsilon_r.data = _mm512_set1_ps(epsilon);
decay_r.data = _mm512_set1_ps(*decay);
decay_r.data = _mm512_set1_ps(decay);
size_t c16 = ((end - start) / C16NUM) * C16NUM + start;
size_t c64 = ((end - start) / C64NUM) * C64NUM + start;
@@ -175,11 +175,11 @@ int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, f
float *v_ptr = v + start;
for (; c1 < c64; c1 += C64NUM) {
struct AVX_Data g_r[4], var_r[4], m_r[4], v_r[4];
LoadStep4(g_r, gradient_ptr);
LoadStep4(var_r, var_ptr);
LoadStep4(m_r, m_ptr);
LoadStep4(v_r, v_ptr);
struct AVX_Data g_r[kUnrollSize], var_r[kUnrollSize], m_r[kUnrollSize], v_r[kUnrollSize];
LoadStep4(g_r, gradient_ptr, kUnrollSize);
LoadStep4(var_r, var_ptr, kUnrollSize);
LoadStep4(m_r, m_ptr, kUnrollSize);
LoadStep4(v_r, v_ptr, kUnrollSize);
m_r[0].data = _mm512_mul_ps(m_r[0].data, beta1_r.data);
m_r[1].data = _mm512_mul_ps(m_r[1].data, beta1_r.data);
@@ -221,9 +221,9 @@ int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, f
var_r[2].data = _mm512_fmadd_ps(g_r[2].data, lr_neg_r.data, var_r[2].data);
var_r[3].data = _mm512_fmadd_ps(g_r[3].data, lr_neg_r.data, var_r[3].data);
StoreStep4(var_ptr, var_r);
StoreStep4(m_ptr, m_r);
StoreStep4(v_ptr, v_r);
StoreStep4(var_ptr, var_r, kUnrollSize);
StoreStep4(m_ptr, m_r, kUnrollSize);
StoreStep4(v_ptr, v_r, kUnrollSize);
gradient_ptr += C64NUM;
var_ptr += C64NUM;
@@ -260,12 +260,12 @@ int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, f
return c1;
}
int FusedAdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float *decay,
int FusedAdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
const int16_t *gradient16, size_t start, size_t end) {
size_t c1 = start;
#ifdef ENABLE_AVX512
float beta1_minus = 1 - beta1;
float beta2_minus = 1 - beta2;
const float beta1_minus = 1 - beta1;
const float beta2_minus = 1 - beta2;
struct AVX_Data beta1_r, beta2_r, beta1_minus_r, beta2_minus_r, lr_neg_r, epsilon_r, decay_r;
beta1_r.data = _mm512_set1_ps(beta1);
beta2_r.data = _mm512_set1_ps(beta2);
@@ -273,7 +273,7 @@ int FusedAdamFp32(float *var, float *m, float *v, float lr, float beta1, float b
beta2_minus_r.data = _mm512_set1_ps(beta2_minus);
lr_neg_r.data = _mm512_set1_ps(-lr);
epsilon_r.data = _mm512_set1_ps(epsilon);
decay_r.data = _mm512_set1_ps(*decay);
decay_r.data = _mm512_set1_ps(decay);
size_t c16 = ((end - start) / C16NUM) * C16NUM + start;
size_t c64 = ((end - start) / C64NUM) * C64NUM + start;
@@ -283,15 +283,15 @@ int FusedAdamFp32(float *var, float *m, float *v, float lr, float beta1, float b
float *v_ptr = v + start;
for (; c1 < c64; c1 += C64NUM) {
struct AVX_Data g_r[4], var_r[4], m_r[4], v_r[4];
struct AVX_Data g_r[kUnrollSize], var_r[kUnrollSize], m_r[kUnrollSize], v_r[kUnrollSize];
g_r[0].data = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(gradient16_ptr)));
g_r[1].data = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(gradient16_ptr + C16NUM)));
g_r[2].data = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(gradient16_ptr + C16NUM * 2)));
g_r[3].data = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(gradient16_ptr + C16NUM * 3)));
LoadStep4(var_r, var_ptr);
LoadStep4(m_r, m_ptr);
LoadStep4(v_r, v_ptr);
LoadStep4(var_r, var_ptr, kUnrollSize);
LoadStep4(m_r, m_ptr, kUnrollSize);
LoadStep4(v_r, v_ptr, kUnrollSize);
m_r[0].data = _mm512_mul_ps(m_r[0].data, beta1_r.data);
m_r[1].data = _mm512_mul_ps(m_r[1].data, beta1_r.data);
@@ -333,9 +333,9 @@ int FusedAdamFp32(float *var, float *m, float *v, float lr, float beta1, float b
var_r[2].data = _mm512_fmadd_ps(g_r[2].data, lr_neg_r.data, var_r[2].data);
var_r[3].data = _mm512_fmadd_ps(g_r[3].data, lr_neg_r.data, var_r[3].data);
StoreStep4(var_ptr, var_r);
StoreStep4(m_ptr, m_r);
StoreStep4(v_ptr, v_r);
StoreStep4(var_ptr, var_r, kUnrollSize);
StoreStep4(m_ptr, m_r, kUnrollSize);
StoreStep4(v_ptr, v_r, kUnrollSize);
gradient16_ptr += C64NUM;
var_ptr += C64NUM;
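The main loop above consumes C64NUM (64) floats per iteration by unrolling the same 16-lane update across kUnrollSize (4) ZMM registers. A sketch of a single 16-float step (assumes AVX-512F; `AdamWeightDecayStep16` is a hypothetical helper, not the exact library code) using the broadcast constants set up with _mm512_set1_ps above:

```c
#include <immintrin.h>

/* One 16-lane AdamWeightDecay step; the C64NUM loop in AdamWeightDecayFp32
 * effectively repeats this for kUnrollSize (4) consecutive registers. */
static inline void AdamWeightDecayStep16(float *var, float *m, float *v, const float *g,
                                         __m512 beta1_r, __m512 beta2_r,
                                         __m512 beta1_minus_r, __m512 beta2_minus_r,
                                         __m512 lr_neg_r, __m512 epsilon_r, __m512 decay_r) {
  __m512 g_r = _mm512_loadu_ps(g);
  __m512 var_r = _mm512_loadu_ps(var);
  __m512 m_r = _mm512_loadu_ps(m);
  __m512 v_r = _mm512_loadu_ps(v);
  /* m = beta1*m + (1-beta1)*g ; v = beta2*v + (1-beta2)*g*g */
  m_r = _mm512_fmadd_ps(g_r, beta1_minus_r, _mm512_mul_ps(m_r, beta1_r));
  v_r = _mm512_fmadd_ps(_mm512_mul_ps(g_r, g_r), beta2_minus_r, _mm512_mul_ps(v_r, beta2_r));
  /* update = m / (sqrt(v) + epsilon) + decay * var ; var += (-lr) * update */
  __m512 update = _mm512_add_ps(_mm512_div_ps(m_r, _mm512_add_ps(_mm512_sqrt_ps(v_r), epsilon_r)),
                                _mm512_mul_ps(decay_r, var_r));
  var_r = _mm512_fmadd_ps(update, lr_neg_r, var_r);
  _mm512_storeu_ps(var, var_r);
  _mm512_storeu_ps(m, m_r);
  _mm512_storeu_ps(v, v_r);
}
```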


@@ -27,22 +27,37 @@
#include <immintrin.h>
#endif
#ifdef ENABLE_AVX512
const size_t kUnrollSize = 4;
struct AVX_Data {
__m512 data;
};
static inline void LoadStep4(struct AVX_Data *inp0, const float *inp1) {
static inline int LoadStep4(struct AVX_Data *inp0, const float *inp1, const size_t arrLen) {
if (arrLen != kUnrollSize) {
return NNACL_ERR;
}
if (inp0 == NULL || inp1 == NULL) {
return NNACL_NULL_PTR;
}
inp0[0].data = _mm512_loadu_ps(inp1);
inp0[1].data = _mm512_loadu_ps(inp1 + C16NUM);
inp0[2].data = _mm512_loadu_ps(inp1 + C16NUM * 2);
inp0[3].data = _mm512_loadu_ps(inp1 + C16NUM * 3);
return NNACL_OK;
}
static inline void StoreStep4(float *inp0, struct AVX_Data *inp1) {
static inline int StoreStep4(float *inp0, struct AVX_Data *inp1, const size_t arrLen) {
if (arrLen != kUnrollSize) {
return NNACL_ERR;
}
if (inp0 == NULL || inp1 == NULL) {
return NNACL_NULL_PTR;
}
_mm512_storeu_ps(inp0, inp1[0].data);
_mm512_storeu_ps(inp0 + C16NUM, inp1[1].data);
_mm512_storeu_ps(inp0 + C16NUM * 2, inp1[2].data);
_mm512_storeu_ps(inp0 + C16NUM * 3, inp1[3].data);
return NNACL_OK;
}
#endif
#ifdef __cplusplus
@@ -52,9 +67,9 @@ int AdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2,
size_t start, size_t end, bool use_nesterov);
int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
const float *gradient, size_t start, size_t end, bool use_nesterov);
int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float *decay,
int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
const float *gradient, size_t start, size_t end);
int FusedAdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float *decay,
int FusedAdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
const int16_t *gradient16, size_t start, size_t end);
#ifdef __cplusplus
}
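LoadStep4 and StoreStep4 now take an explicit length and return an nnacl status code instead of void, so callers pass kUnrollSize and can propagate errors. A hedged caller sketch (assumes the declarations above are in scope inside the ENABLE_AVX512 region; `CopyStep4` is a hypothetical name):

```c
/* Hypothetical caller: moves 64 floats through kUnrollSize ZMM registers and
 * forwards the new status codes instead of assuming success. */
static int CopyStep4(float *dst, const float *src) {
  struct AVX_Data regs[kUnrollSize];
  int ret = LoadStep4(regs, src, kUnrollSize);
  if (ret != NNACL_OK) {
    return ret;
  }
  return StoreStep4(dst, regs, kUnrollSize);
}
```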


@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""AdamWeightDecay, a customized Adam for pangu1. Input: gradient."""
"""AdamWeightDecay, a customized Adam for pangu alpha. Input: gradient."""
import numpy as np
from mindspore.common import dtype as mstype
@@ -35,15 +35,14 @@ def _update_run_kernel(opt, beta1, beta2, eps, lr, weight_decay, param, m, v, gr
"""
Update parameters by AdamWeightDecay op.
"""
success = True
if optim_filter:
op_cast = P.Cast()
gradient_fp32 = op_cast(gradient, mstype.float32)
if decay_flags:
next_param = opt(param, m, v, lr, beta1, beta2, eps, F.cast(weight_decay, mstype.float32), gradient_fp32)
next_param = opt(param, m, v, lr, beta1, beta2, eps, weight_decay, gradient)
else:
next_param = opt(param, m, v, lr, beta1, beta2, eps, F.cast(0.0, mstype.float32), gradient_fp32)
return next_param
return gradient
next_param = opt(param, m, v, lr, beta1, beta2, eps, 0.0, gradient)
return F.depend(success, next_param)
return success
def _check_param_value(beta1, beta2, eps, prim_name):
@@ -145,18 +144,17 @@ class AdamWeightDecayOp(Optimizer):
lr = self.get_lr()
if self.is_group:
if self.is_group_lr:
optim_result = self.map_(F.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps),
lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
optim_result = self.map_reverse(F.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps),
lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
else:
optim_result = self.map_(F.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps, lr),
self.weight_decay, self.parameters, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
optim_result = self.map_reverse(F.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps, lr),
self.weight_decay, self.parameters, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
else:
optim_result = self.map_(F.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps, lr,
self.weight_decay),
self.parameters, self.moments1, self.moments2, gradients, self.decay_flags,
self.optim_filter)
optim_result = self.map_reverse(F.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps, lr,
self.weight_decay), self.parameters, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
if self.use_parallel:
self.broadcast_params(optim_result)
return optim_result


@@ -326,6 +326,10 @@ def add_training_params(opt):
default="adam",
choices=["adam", "lamb"],
help="select which optimizer to be used, default adam")
opt.add_argument("--opt_offload",
type=int,
default=0,
help="Enable optimizer status offload to host CPU, default is 0")
opt.add_argument("--eod_id",
type=int,
default=6,


@@ -30,6 +30,7 @@ import mindspore.common.dtype as mstype
from mindspore.parallel import set_algo_parameters
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell
from src.adam import AdamWeightDecayOp
from src.dataset import create_dataset
from src.pangu_alpha import PanguAlpha, PanguAlphaWithLoss, CrossEntropyLoss
from src.pangu_alpha_wrapcell import PanguAlphaTrainOneStepWithLossScaleCell, PanguAlphaTrainPipelineWithLossScaleCell
@@ -76,13 +77,31 @@ project_root = os.path.abspath(
print('project_root:', project_root)
def set_weight_decay(params):
"""
Set weight decay coefficient, zero for bias and layernorm, 1e-1 for rest
"""
decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower()
decay_params = list(filter(decay_filter, params))
other_params = list(filter(lambda x: not decay_filter(x), params))
group_params = [{
'params': decay_params,
'weight_decay': 1e-1
}, {
'params': other_params,
'weight_decay': 0.0
}, {
'order_params': params
}]
return group_params
def run_train(args_opt):
r"""
The main training process.
"""
# Set execution mode
context.set_context(mode=context.GRAPH_MODE,
device_target=args_opt.device_target)
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
context.set_context(variable_memory_max_size="30GB")
# Set parallel context
if args_opt.distribute == "true":
@@ -149,22 +168,12 @@ def run_train(args_opt):
warmup_steps=args_opt.warmup_step,
decay_steps=200000)
# Set weight decay coefficient, zero for bias and layernorm, 1e-1 for rest
decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower()
params = pangu_alpha.trainable_params()
decay_params = list(filter(decay_filter, params))
other_params = list(filter(lambda x: not decay_filter(x), params))
group_params = [{
'params': decay_params,
'weight_decay': 1e-1
}, {
'params': other_params,
'weight_decay': 0.0
}, {
'order_params': params
}]
group_params = set_weight_decay(params)
if args_opt.optimizer == "lamb":
optimizer = nn.Lamb(group_params, learning_rate=lr)
elif args_opt.opt_offload:
optimizer = AdamWeightDecayOp(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95)
else:
optimizer = FP32StateAdamWeightDecay(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95)
# Initial scaling sens
@@ -201,6 +210,7 @@
print("Dataset size: {}, actual_epoch_num: {}".format(ds.get_dataset_size(), actual_epoch_num), flush=True)
model.train(actual_epoch_num, ds, callbacks=callback, sink_size=callback_size, dataset_sink_mode=True)
def run_train_pipeline(args_opt):
r"""
The main training process in pipeline.
@@ -263,20 +273,11 @@ def run_train_pipeline(args_opt):
lr = LearningRate(learning_rate=args_opt.start_lr, end_learning_rate=args_opt.end_lr,
warmup_steps=args_opt.warmup_step, decay_steps=args_opt.decay_steps)
params = pangu_alpha.infer_param_pipeline_stage()
decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower()
decay_params = list(filter(decay_filter, params))
other_params = list(filter(lambda x: not decay_filter(x), params))
group_params = [{
'params': decay_params,
'weight_decay': 1e-1
}, {
'params': other_params,
'weight_decay': 0.0
}, {
'order_params': params
}]
group_params = set_weight_decay(params)
if args_opt.optimizer == "lamb":
optimizer = nn.Lamb(group_params, learning_rate=lr)
elif args_opt.opt_offload:
optimizer = AdamWeightDecayOp(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95)
else:
optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, beta1=0.9, beta2=0.95, eps=1e-8)
@@ -297,6 +298,7 @@ def run_train_pipeline(args_opt):
model.train(actual_epoch_num, ds, callbacks=callback,
sink_size=callback_size, dataset_sink_mode=True)
if __name__ == "__main__":
opt = get_args()
set_parse(opt)


@@ -22,7 +22,7 @@ from mindspore import Tensor
from mindspore.nn import Dense
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.ops import operations as P
from model_zoo.official.nlp.gpt.src.adam import AdamWeightDecayOp
from model_zoo.official.nlp.pangu_alpha.src.adam import AdamWeightDecayOp
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")