forked from mindspore-Ecosystem/mindspore
!18769 fix adam/adamdelta optimizer bug
Merge pull request !18769 from wangyanling/optimizer
This commit is contained in:
commit
33d4fc768c
|
@ -74,6 +74,7 @@ void AdamCPUKernel::LaunchAdamNnacl(const std::vector<kernel::AddressPtr> &input
|
|||
MS_LOG(EXCEPTION) << "The beta1_power can't be set 1.";
|
||||
}
|
||||
float new_lr = lr * std::sqrt(1.0 - beta2_power) / (1 - beta1_power);
|
||||
|
||||
// multithreading
|
||||
size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1;
|
||||
auto task = [&](size_t start, size_t end) {
|
||||
|
|
|
@ -29,63 +29,55 @@ int AdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2,
|
|||
size_t start, size_t end, bool use_nesterov) {
|
||||
size_t c1 = start;
|
||||
#ifdef ENABLE_AVX
|
||||
float coeff1 = 1 - beta1;
|
||||
float coeff2 = 1 - beta2;
|
||||
const float *m_ptr = m;
|
||||
const float *v_ptr = v;
|
||||
float *var_ptr = var;
|
||||
const float *gradient_ptr = gradient;
|
||||
size_t c8 = ((end - start) / C8NUM) * C8NUM;
|
||||
__m256 avx_r0, avx_r1, avx_r2, avx_r3, avx_r4, avx_r5, avx_r6, gradient_r;
|
||||
__m256 coeff1_r = _mm256_set1_ps(1 - beta1);
|
||||
__m256 coeff2_r = _mm256_set1_ps(1 - beta2);
|
||||
__m256 beta1_r = _mm256_set1_ps(beta1);
|
||||
__m256 lr_r = _mm256_set1_ps(lr);
|
||||
__m256 epsi_r = _mm256_set1_ps(epsilon);
|
||||
|
||||
for (; c1 < c8; c1 += C8NUM) {
|
||||
avx_r0 = _mm256_set1_ps(coeff1);
|
||||
gradient_r = _mm256_loadu_ps(gradient_ptr);
|
||||
avx_r2 = _mm256_loadu_ps(m_ptr);
|
||||
avx_r3 = _mm256_sub_ps(gradient_r, avx_r2);
|
||||
avx_r4 = _mm256_mul_ps(avx_r3, avx_r0);
|
||||
avx_r3 = _mm256_add_ps(avx_r4, avx_r2); // m[i]~m[i+8]
|
||||
float *var_ptr = var + start;
|
||||
float *m_ptr = m + start;
|
||||
float *v_ptr = v + start;
|
||||
const float *grad_ptr = gradient + start;
|
||||
|
||||
avx_r2 = _mm256_mul_ps(gradient_r, gradient_r);
|
||||
avx_r4 = _mm256_loadu_ps(v_ptr);
|
||||
avx_r5 = _mm256_sub_ps(avx_r2, avx_r4);
|
||||
avx_r1 = _mm256_set1_ps(coeff2);
|
||||
avx_r2 = _mm256_mul_ps(avx_r5, avx_r1);
|
||||
avx_r5 = _mm256_add_ps(avx_r4, avx_r2); // v[i]~v[i+8]
|
||||
__m256 avx_r0, avx_r1;
|
||||
__m256 var_r, m_r, v_r, grad_r;
|
||||
|
||||
for (; c1 < start + c8; c1 += C8NUM) {
|
||||
grad_r = _mm256_loadu_ps(grad_ptr);
|
||||
m_r = _mm256_loadu_ps(m_ptr);
|
||||
avx_r0 = _mm256_sub_ps(grad_r, m_r);
|
||||
avx_r1 = _mm256_mul_ps(avx_r0, coeff1_r);
|
||||
m_r = _mm256_add_ps(m_r, avx_r1);
|
||||
_mm256_storeu_ps(m_ptr, m_r);
|
||||
|
||||
v_r = _mm256_loadu_ps(v_ptr);
|
||||
avx_r0 = _mm256_sub_ps(_mm256_mul_ps(grad_r, grad_r), v_r);
|
||||
v_r = _mm256_add_ps(v_r, _mm256_mul_ps(avx_r0, coeff2_r));
|
||||
_mm256_storeu_ps(v_ptr, v_r);
|
||||
|
||||
if (use_nesterov) {
|
||||
avx_r1 = _mm256_set1_ps(beta1);
|
||||
avx_r2 = _mm256_mul_ps(avx_r3, avx_r1);
|
||||
avx_r4 = _mm256_mul_ps(gradient_r, avx_r0);
|
||||
avx_r6 = _mm256_add_ps(avx_r2, avx_r4);
|
||||
avx_r0 = _mm256_set1_ps(lr);
|
||||
avx_r2 = _mm256_mul_ps(avx_r6, avx_r0);
|
||||
avx_r0 = _mm256_add_ps(_mm256_mul_ps(m_r, beta1_r), _mm256_mul_ps(coeff1_r, grad_r));
|
||||
avx_r1 = _mm256_mul_ps(lr_r, avx_r0);
|
||||
avx_r0 = _mm256_add_ps(_mm256_sqrt_ps(v_r), epsi_r);
|
||||
__m256 avx_r2 = _mm256_div_ps(avx_r1, avx_r0);
|
||||
|
||||
avx_r0 = _mm256_set1_ps(epsilon);
|
||||
avx_r1 = _mm256_sqrt_ps(avx_r5);
|
||||
avx_r4 = _mm256_add_ps(avx_r0, avx_r1);
|
||||
|
||||
avx_r0 = _mm256_div_ps(avx_r2, avx_r2);
|
||||
avx_r1 = _mm256_loadu_ps(var_ptr);
|
||||
avx_r2 = _mm256_sub_ps(avx_r1, avx_r0);
|
||||
_mm256_storeu_ps(var_ptr, avx_r2);
|
||||
var_r = _mm256_loadu_ps(var_ptr);
|
||||
var_r = _mm256_sub_ps(var_r, avx_r2);
|
||||
_mm256_storeu_ps(var_ptr, var_r);
|
||||
} else {
|
||||
avx_r0 = _mm256_set1_ps(lr);
|
||||
avx_r1 = _mm256_mul_ps(avx_r3, avx_r0);
|
||||
|
||||
avx_r0 = _mm256_set1_ps(epsilon);
|
||||
avx_r2 = _mm256_sqrt_ps(avx_r5);
|
||||
avx_r4 = _mm256_add_ps(avx_r0, avx_r2);
|
||||
|
||||
avx_r0 = _mm256_div_ps(avx_r1, avx_r4);
|
||||
avx_r1 = _mm256_loadu_ps(var_ptr);
|
||||
avx_r3 = _mm256_sub_ps(avx_r1, avx_r0);
|
||||
_mm256_storeu_ps(var_ptr, avx_r3);
|
||||
avx_r0 = _mm256_mul_ps(lr_r, m_r);
|
||||
avx_r1 = _mm256_add_ps(_mm256_sqrt_ps(v_r), epsi_r);
|
||||
__m256 avx_r2 = _mm256_div_ps(avx_r0, avx_r1);
|
||||
var_r = _mm256_loadu_ps(var_ptr);
|
||||
var_r = _mm256_sub_ps(var_r, avx_r2);
|
||||
_mm256_storeu_ps(var_ptr, var_r);
|
||||
}
|
||||
m_ptr += C8NUM;
|
||||
v_ptr += C8NUM;
|
||||
var_ptr += C8NUM;
|
||||
gradient_ptr += C8NUM;
|
||||
grad_ptr += C8NUM;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
@ -106,62 +98,50 @@ int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float
|
|||
const float *gradient, size_t start, size_t end, bool use_nesterov) {
|
||||
size_t c1 = start;
|
||||
#ifdef ENABLE_AVX
|
||||
float coeff1 = 1 - beta1;
|
||||
float coeff2 = 1 - beta2;
|
||||
float *m_ptr = m;
|
||||
float *v_ptr = v;
|
||||
float *delta_ptr = delta;
|
||||
const float *gradient_ptr = gradient;
|
||||
size_t c8 = ((end - start) / C8NUM) * C8NUM;
|
||||
__m256 coeff1_r = _mm256_set1_ps(1.0f - beta1);
|
||||
__m256 coeff2_r = _mm256_set1_ps(1.0f - beta2);
|
||||
__m256 beta1_r = _mm256_set1_ps(beta1);
|
||||
__m256 beta2_r = _mm256_set1_ps(beta2);
|
||||
__m256 lr_r = _mm256_set1_ps(-lr);
|
||||
__m256 epsi_r = _mm256_set1_ps(epsilon);
|
||||
|
||||
__m256 gradient_r0, m_r1, v_r2, beta1_r3, beta2_r4, var_r5, var_r6, var_r7;
|
||||
for (; c1 < c8 + start; c1 += C8NUM) {
|
||||
gradient_r0 = _mm256_loadu_ps(gradient_ptr); // static
|
||||
beta1_r3 = _mm256_set1_ps(beta1); // static
|
||||
var_r5 = _mm256_loadu_ps(m_ptr);
|
||||
var_r6 = _mm256_mul_ps(beta1_r3, var_r5); // m[i] = m[i] * beta1
|
||||
var_r7 = _mm256_set1_ps(coeff1);
|
||||
var_r5 = _mm256_mul_ps(var_r7, gradient_r0); //
|
||||
m_r1 = _mm256_add_ps(var_r6, var_r5);
|
||||
_mm256_storeu_ps(m_ptr, m_r1);
|
||||
float *m_ptr = m + start;
|
||||
float *v_ptr = v + start;
|
||||
float *delta_ptr = delta + start;
|
||||
const float *gradient_ptr = gradient + start;
|
||||
|
||||
beta2_r4 = _mm256_set1_ps(beta2); // static
|
||||
var_r5 = _mm256_loadu_ps(v_ptr);
|
||||
var_r6 = _mm256_mul_ps(beta2_r4, var_r5); // v[i] * beta2
|
||||
var_r7 = _mm256_set1_ps(coeff2);
|
||||
var_r5 = _mm256_mul_ps(var_r7, gradient_r0);
|
||||
var_r7 = _mm256_mul_ps(var_r5, gradient_r0);
|
||||
v_r2 = _mm256_add_ps(var_r7, var_r6);
|
||||
_mm256_storeu_ps(v_ptr, v_r2);
|
||||
__m256 m_r, v_r, delta_r, grad_r;
|
||||
__m256 avx_r0, avx_r1;
|
||||
for (; c1 < start + c8; c1 += C8NUM) {
|
||||
m_r = _mm256_loadu_ps(m_ptr);
|
||||
avx_r0 = _mm256_mul_ps(m_r, beta1_r);
|
||||
grad_r = _mm256_loadu_ps(gradient_ptr);
|
||||
m_r = _mm256_add_ps(avx_r0, _mm256_mul_ps(coeff1_r, grad_r));
|
||||
_mm256_storeu_ps(m_ptr, m_r);
|
||||
|
||||
v_r = _mm256_loadu_ps(v_ptr);
|
||||
avx_r0 = _mm256_mul_ps(v_r, beta2_r);
|
||||
avx_r1 = _mm256_mul_ps(_mm256_mul_ps(coeff2_r, grad_r), grad_r);
|
||||
v_r = _mm256_add_ps(avx_r0, avx_r1);
|
||||
_mm256_storeu_ps(v_ptr, v_r);
|
||||
|
||||
if (use_nesterov) {
|
||||
var_r5 = _mm256_mul_ps(beta1_r3, m_r1);
|
||||
var_r6 = _mm256_set1_ps(coeff1);
|
||||
var_r7 = _mm256_mul_ps(gradient_r0, var_r6);
|
||||
var_r6 = _mm256_add_ps(var_r5, var_r7); // m[i] * beta1 + (1 - beta1) * grad[i]
|
||||
var_r5 = _mm256_set1_ps(lr);
|
||||
var_r7 = _mm256_mul_ps(var_r6, var_r5);
|
||||
|
||||
var_r5 = _mm256_set1_ps(epsilon);
|
||||
var_r6 = _mm256_sqrt_ps(v_r2);
|
||||
v_r2 = _mm256_add_ps(var_r5, var_r6);
|
||||
var_r5 = _mm256_div_ps(var_r7, v_r2);
|
||||
var_r6 = _mm256_set1_ps(0.f);
|
||||
var_r7 = _mm256_sub_ps(var_r6, var_r5);
|
||||
_mm256_storeu_ps(delta_ptr, var_r7);
|
||||
avx_r0 = _mm256_add_ps(_mm256_mul_ps(m_r, beta1_r), _mm256_mul_ps(coeff1_r, grad_r));
|
||||
avx_r0 = _mm256_mul_ps(lr_r, avx_r0);
|
||||
avx_r1 = _mm256_add_ps(_mm256_sqrt_ps(v_r), epsi_r);
|
||||
delta_r = _mm256_div_ps(avx_r0, avx_r1);
|
||||
_mm256_storeu_ps(delta_ptr, delta_r);
|
||||
} else {
|
||||
var_r5 = _mm256_set1_ps(lr);
|
||||
var_r6 = _mm256_mul_ps(var_r5, m_r1);
|
||||
|
||||
var_r7 = _mm256_set1_ps(epsilon);
|
||||
var_r5 = _mm256_sqrt_ps(v_r2);
|
||||
v_r2 = _mm256_add_ps(var_r5, var_r7);
|
||||
|
||||
var_r5 = _mm256_div_ps(var_r6, v_r2);
|
||||
var_r6 = _mm256_set1_ps(0.f);
|
||||
var_r7 = _mm256_sub_ps(var_r6, var_r5);
|
||||
_mm256_storeu_ps(delta_ptr, var_r7);
|
||||
avx_r0 = _mm256_mul_ps(lr_r, m_r);
|
||||
avx_r1 = _mm256_add_ps(_mm256_sqrt_ps(v_r), epsi_r);
|
||||
delta_r = _mm256_div_ps(avx_r0, avx_r1);
|
||||
_mm256_storeu_ps(delta_ptr, delta_r);
|
||||
}
|
||||
m_ptr += C8NUM;
|
||||
v_ptr += C8NUM;
|
||||
delta_ptr += C8NUM;
|
||||
gradient_ptr += C8NUM;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
|
@ -29,10 +29,10 @@ int RMSPropUnuseCenterFp32(float *variable, float *mean_square, float *moment, f
|
|||
size_t c1 = start;
|
||||
#ifdef ENABLE_AVX
|
||||
size_t c8 = ((end - start) / C8NUM) * C8NUM;
|
||||
float *variable_ptr = variable;
|
||||
float *mean_square_ptr = mean_square;
|
||||
float *gradients_ptr = gradients;
|
||||
float *moment_ptr = moment;
|
||||
float *variable_ptr = variable + start;
|
||||
float *mean_square_ptr = mean_square + start;
|
||||
float *gradients_ptr = gradients + start;
|
||||
float *moment_ptr = moment + start;
|
||||
|
||||
__m256 decay_r = _mm256_set1_ps(1.0 - decay);
|
||||
__m256 momentum_r = _mm256_set1_ps(momentum);
|
||||
|
@ -78,11 +78,11 @@ int RMSPropUseCenterFp32(float *variable, float *mean_square, float *moment, flo
|
|||
size_t c1 = start;
|
||||
#ifdef ENABLE_AVX
|
||||
size_t c8 = ((end - start) / C8NUM) * C8NUM;
|
||||
float *variable_ptr = variable;
|
||||
float *mean_gradients_ptr = mean_gradients;
|
||||
float *mean_square_ptr = mean_square;
|
||||
float *moment_ptr = moment;
|
||||
float *gradients_ptr = gradients;
|
||||
float *variable_ptr = variable + start;
|
||||
float *mean_gradients_ptr = mean_gradients + start;
|
||||
float *mean_square_ptr = mean_square + start;
|
||||
float *moment_ptr = moment + start;
|
||||
float *gradients_ptr = gradients + start;
|
||||
|
||||
__m256 decay_r = _mm256_set1_ps(1.0 - decay);
|
||||
__m256 momentum_r = _mm256_set1_ps(momentum);
|
||||
|
|
Loading…
Reference in New Issue