!18769 fix adam/deltadam optimizer bug

Merge pull request !18769 from wangyanling/optimizer
This commit is contained in:
i-robot 2021-06-24 01:43:12 +00:00 committed by Gitee
commit 33d4fc768c
3 changed files with 85 additions and 104 deletions

View File

@ -74,6 +74,7 @@ void AdamCPUKernel::LaunchAdamNnacl(const std::vector<kernel::AddressPtr> &input
MS_LOG(EXCEPTION) << "The beta1_power can't be set 1.";
}
float new_lr = lr * std::sqrt(1.0 - beta2_power) / (1 - beta1_power);
// multithreading
size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1;
auto task = [&](size_t start, size_t end) {

View File

@ -29,63 +29,55 @@ int AdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2,
size_t start, size_t end, bool use_nesterov) {
size_t c1 = start;
#ifdef ENABLE_AVX
float coeff1 = 1 - beta1;
float coeff2 = 1 - beta2;
const float *m_ptr = m;
const float *v_ptr = v;
float *var_ptr = var;
const float *gradient_ptr = gradient;
size_t c8 = ((end - start) / C8NUM) * C8NUM;
__m256 avx_r0, avx_r1, avx_r2, avx_r3, avx_r4, avx_r5, avx_r6, gradient_r;
__m256 coeff1_r = _mm256_set1_ps(1 - beta1);
__m256 coeff2_r = _mm256_set1_ps(1 - beta2);
__m256 beta1_r = _mm256_set1_ps(beta1);
__m256 lr_r = _mm256_set1_ps(lr);
__m256 epsi_r = _mm256_set1_ps(epsilon);
for (; c1 < c8; c1 += C8NUM) {
avx_r0 = _mm256_set1_ps(coeff1);
gradient_r = _mm256_loadu_ps(gradient_ptr);
avx_r2 = _mm256_loadu_ps(m_ptr);
avx_r3 = _mm256_sub_ps(gradient_r, avx_r2);
avx_r4 = _mm256_mul_ps(avx_r3, avx_r0);
avx_r3 = _mm256_add_ps(avx_r4, avx_r2); // m[i]~m[i+8]
float *var_ptr = var + start;
float *m_ptr = m + start;
float *v_ptr = v + start;
const float *grad_ptr = gradient + start;
avx_r2 = _mm256_mul_ps(gradient_r, gradient_r);
avx_r4 = _mm256_loadu_ps(v_ptr);
avx_r5 = _mm256_sub_ps(avx_r2, avx_r4);
avx_r1 = _mm256_set1_ps(coeff2);
avx_r2 = _mm256_mul_ps(avx_r5, avx_r1);
avx_r5 = _mm256_add_ps(avx_r4, avx_r2); // v[i]~v[i+8]
__m256 avx_r0, avx_r1;
__m256 var_r, m_r, v_r, grad_r;
for (; c1 < start + c8; c1 += C8NUM) {
grad_r = _mm256_loadu_ps(grad_ptr);
m_r = _mm256_loadu_ps(m_ptr);
avx_r0 = _mm256_sub_ps(grad_r, m_r);
avx_r1 = _mm256_mul_ps(avx_r0, coeff1_r);
m_r = _mm256_add_ps(m_r, avx_r1);
_mm256_storeu_ps(m_ptr, m_r);
v_r = _mm256_loadu_ps(v_ptr);
avx_r0 = _mm256_sub_ps(_mm256_mul_ps(grad_r, grad_r), v_r);
v_r = _mm256_add_ps(v_r, _mm256_mul_ps(avx_r0, coeff2_r));
_mm256_storeu_ps(v_ptr, v_r);
if (use_nesterov) {
avx_r1 = _mm256_set1_ps(beta1);
avx_r2 = _mm256_mul_ps(avx_r3, avx_r1);
avx_r4 = _mm256_mul_ps(gradient_r, avx_r0);
avx_r6 = _mm256_add_ps(avx_r2, avx_r4);
avx_r0 = _mm256_set1_ps(lr);
avx_r2 = _mm256_mul_ps(avx_r6, avx_r0);
avx_r0 = _mm256_add_ps(_mm256_mul_ps(m_r, beta1_r), _mm256_mul_ps(coeff1_r, grad_r));
avx_r1 = _mm256_mul_ps(lr_r, avx_r0);
avx_r0 = _mm256_add_ps(_mm256_sqrt_ps(v_r), epsi_r);
__m256 avx_r2 = _mm256_div_ps(avx_r1, avx_r0);
avx_r0 = _mm256_set1_ps(epsilon);
avx_r1 = _mm256_sqrt_ps(avx_r5);
avx_r4 = _mm256_add_ps(avx_r0, avx_r1);
avx_r0 = _mm256_div_ps(avx_r2, avx_r2);
avx_r1 = _mm256_loadu_ps(var_ptr);
avx_r2 = _mm256_sub_ps(avx_r1, avx_r0);
_mm256_storeu_ps(var_ptr, avx_r2);
var_r = _mm256_loadu_ps(var_ptr);
var_r = _mm256_sub_ps(var_r, avx_r2);
_mm256_storeu_ps(var_ptr, var_r);
} else {
avx_r0 = _mm256_set1_ps(lr);
avx_r1 = _mm256_mul_ps(avx_r3, avx_r0);
avx_r0 = _mm256_set1_ps(epsilon);
avx_r2 = _mm256_sqrt_ps(avx_r5);
avx_r4 = _mm256_add_ps(avx_r0, avx_r2);
avx_r0 = _mm256_div_ps(avx_r1, avx_r4);
avx_r1 = _mm256_loadu_ps(var_ptr);
avx_r3 = _mm256_sub_ps(avx_r1, avx_r0);
_mm256_storeu_ps(var_ptr, avx_r3);
avx_r0 = _mm256_mul_ps(lr_r, m_r);
avx_r1 = _mm256_add_ps(_mm256_sqrt_ps(v_r), epsi_r);
__m256 avx_r2 = _mm256_div_ps(avx_r0, avx_r1);
var_r = _mm256_loadu_ps(var_ptr);
var_r = _mm256_sub_ps(var_r, avx_r2);
_mm256_storeu_ps(var_ptr, var_r);
}
m_ptr += C8NUM;
v_ptr += C8NUM;
var_ptr += C8NUM;
gradient_ptr += C8NUM;
grad_ptr += C8NUM;
}
#endif
@ -106,62 +98,50 @@ int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float
const float *gradient, size_t start, size_t end, bool use_nesterov) {
size_t c1 = start;
#ifdef ENABLE_AVX
float coeff1 = 1 - beta1;
float coeff2 = 1 - beta2;
float *m_ptr = m;
float *v_ptr = v;
float *delta_ptr = delta;
const float *gradient_ptr = gradient;
size_t c8 = ((end - start) / C8NUM) * C8NUM;
__m256 coeff1_r = _mm256_set1_ps(1.0f - beta1);
__m256 coeff2_r = _mm256_set1_ps(1.0f - beta2);
__m256 beta1_r = _mm256_set1_ps(beta1);
__m256 beta2_r = _mm256_set1_ps(beta2);
__m256 lr_r = _mm256_set1_ps(-lr);
__m256 epsi_r = _mm256_set1_ps(epsilon);
__m256 gradient_r0, m_r1, v_r2, beta1_r3, beta2_r4, var_r5, var_r6, var_r7;
for (; c1 < c8 + start; c1 += C8NUM) {
gradient_r0 = _mm256_loadu_ps(gradient_ptr); // static
beta1_r3 = _mm256_set1_ps(beta1); // static
var_r5 = _mm256_loadu_ps(m_ptr);
var_r6 = _mm256_mul_ps(beta1_r3, var_r5); // m[i] = m[i] * beta1
var_r7 = _mm256_set1_ps(coeff1);
var_r5 = _mm256_mul_ps(var_r7, gradient_r0); //
m_r1 = _mm256_add_ps(var_r6, var_r5);
_mm256_storeu_ps(m_ptr, m_r1);
float *m_ptr = m + start;
float *v_ptr = v + start;
float *delta_ptr = delta + start;
const float *gradient_ptr = gradient + start;
beta2_r4 = _mm256_set1_ps(beta2); // static
var_r5 = _mm256_loadu_ps(v_ptr);
var_r6 = _mm256_mul_ps(beta2_r4, var_r5); // v[i] * beta2
var_r7 = _mm256_set1_ps(coeff2);
var_r5 = _mm256_mul_ps(var_r7, gradient_r0);
var_r7 = _mm256_mul_ps(var_r5, gradient_r0);
v_r2 = _mm256_add_ps(var_r7, var_r6);
_mm256_storeu_ps(v_ptr, v_r2);
__m256 m_r, v_r, delta_r, grad_r;
__m256 avx_r0, avx_r1;
for (; c1 < start + c8; c1 += C8NUM) {
m_r = _mm256_loadu_ps(m_ptr);
avx_r0 = _mm256_mul_ps(m_r, beta1_r);
grad_r = _mm256_loadu_ps(gradient_ptr);
m_r = _mm256_add_ps(avx_r0, _mm256_mul_ps(coeff1_r, grad_r));
_mm256_storeu_ps(m_ptr, m_r);
v_r = _mm256_loadu_ps(v_ptr);
avx_r0 = _mm256_mul_ps(v_r, beta2_r);
avx_r1 = _mm256_mul_ps(_mm256_mul_ps(coeff2_r, grad_r), grad_r);
v_r = _mm256_add_ps(avx_r0, avx_r1);
_mm256_storeu_ps(v_ptr, v_r);
if (use_nesterov) {
var_r5 = _mm256_mul_ps(beta1_r3, m_r1);
var_r6 = _mm256_set1_ps(coeff1);
var_r7 = _mm256_mul_ps(gradient_r0, var_r6);
var_r6 = _mm256_add_ps(var_r5, var_r7); // m[i] * beta1 + (1 - beta1) * grad[i]
var_r5 = _mm256_set1_ps(lr);
var_r7 = _mm256_mul_ps(var_r6, var_r5);
var_r5 = _mm256_set1_ps(epsilon);
var_r6 = _mm256_sqrt_ps(v_r2);
v_r2 = _mm256_add_ps(var_r5, var_r6);
var_r5 = _mm256_div_ps(var_r7, v_r2);
var_r6 = _mm256_set1_ps(0.f);
var_r7 = _mm256_sub_ps(var_r6, var_r5);
_mm256_storeu_ps(delta_ptr, var_r7);
avx_r0 = _mm256_add_ps(_mm256_mul_ps(m_r, beta1_r), _mm256_mul_ps(coeff1_r, grad_r));
avx_r0 = _mm256_mul_ps(lr_r, avx_r0);
avx_r1 = _mm256_add_ps(_mm256_sqrt_ps(v_r), epsi_r);
delta_r = _mm256_div_ps(avx_r0, avx_r1);
_mm256_storeu_ps(delta_ptr, delta_r);
} else {
var_r5 = _mm256_set1_ps(lr);
var_r6 = _mm256_mul_ps(var_r5, m_r1);
var_r7 = _mm256_set1_ps(epsilon);
var_r5 = _mm256_sqrt_ps(v_r2);
v_r2 = _mm256_add_ps(var_r5, var_r7);
var_r5 = _mm256_div_ps(var_r6, v_r2);
var_r6 = _mm256_set1_ps(0.f);
var_r7 = _mm256_sub_ps(var_r6, var_r5);
_mm256_storeu_ps(delta_ptr, var_r7);
avx_r0 = _mm256_mul_ps(lr_r, m_r);
avx_r1 = _mm256_add_ps(_mm256_sqrt_ps(v_r), epsi_r);
delta_r = _mm256_div_ps(avx_r0, avx_r1);
_mm256_storeu_ps(delta_ptr, delta_r);
}
m_ptr += C8NUM;
v_ptr += C8NUM;
delta_ptr += C8NUM;
gradient_ptr += C8NUM;
}
#endif

View File

@ -29,10 +29,10 @@ int RMSPropUnuseCenterFp32(float *variable, float *mean_square, float *moment, f
size_t c1 = start;
#ifdef ENABLE_AVX
size_t c8 = ((end - start) / C8NUM) * C8NUM;
float *variable_ptr = variable;
float *mean_square_ptr = mean_square;
float *gradients_ptr = gradients;
float *moment_ptr = moment;
float *variable_ptr = variable + start;
float *mean_square_ptr = mean_square + start;
float *gradients_ptr = gradients + start;
float *moment_ptr = moment + start;
__m256 decay_r = _mm256_set1_ps(1.0 - decay);
__m256 momentum_r = _mm256_set1_ps(momentum);
@ -78,11 +78,11 @@ int RMSPropUseCenterFp32(float *variable, float *mean_square, float *moment, flo
size_t c1 = start;
#ifdef ENABLE_AVX
size_t c8 = ((end - start) / C8NUM) * C8NUM;
float *variable_ptr = variable;
float *mean_gradients_ptr = mean_gradients;
float *mean_square_ptr = mean_square;
float *moment_ptr = moment;
float *gradients_ptr = gradients;
float *variable_ptr = variable + start;
float *mean_gradients_ptr = mean_gradients + start;
float *mean_square_ptr = mean_square + start;
float *moment_ptr = moment + start;
float *gradients_ptr = gradients + start;
__m256 decay_r = _mm256_set1_ps(1.0 - decay);
__m256 momentum_r = _mm256_set1_ps(momentum);