!18387 fix optimizer bug

Merge pull request !18387 from wangyanling/optimizer
This commit is contained in:
i-robot 2021-06-17 21:45:10 +08:00 committed by Gitee
commit 1a694111e5
2 changed files with 54 additions and 63 deletions

View File

@ -30,7 +30,7 @@ constexpr size_t kOutputSize = 2;
void ApplyAdagradCPUKernel::InitKernel(const CNodePtr &kernel_node) { void ApplyAdagradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
update_slots_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "update_slots"); update_slots_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "update_slots");
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
} }
bool ApplyAdagradCPUKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, bool ApplyAdagradCPUKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,

View File

@ -28,39 +28,35 @@ int RMSPropUnuseCenterFp32(float *variable, float *mean_square, float *moment, f
float learning_rate, float decay, float epsilon, size_t start, size_t end) { float learning_rate, float decay, float epsilon, size_t start, size_t end) {
size_t c1 = start; size_t c1 = start;
#ifdef ENABLE_AVX #ifdef ENABLE_AVX
size_t c8 = ((end - start) / C8NUM) * C8NUM;
float *variable_ptr = variable; float *variable_ptr = variable;
float *mean_square_ptr = mean_square; float *mean_square_ptr = mean_square;
float *gradients_ptr = gradients; float *gradients_ptr = gradients;
float *moment_ptr = moment; float *moment_ptr = moment;
float decay_v = 1.0 - decay;
size_t c8 = ((end - start) / C8NUM) * C8NUM; __m256 decay_r = _mm256_set1_ps(1.0 - decay);
__m256 decay_r = _mm256_set1_ps(decay_v); __m256 momentum_r = _mm256_set1_ps(momentum);
__m256 moment_r = _mm256_set1_ps(momentum);
__m256 lr_r = _mm256_set1_ps(learning_rate); __m256 lr_r = _mm256_set1_ps(learning_rate);
__m256 gradient_r, mean_square_r, tmp_r1, tmp_r2, tmp_r3; __m256 epsi_r = _mm256_set1_ps(epsilon);
for (; c1 < c8; c1 += C8NUM) { __m256 gradient_r, mean_square_r, moment_r, variable_r, avx_r1, avx_r2;
for (; c1 < start + c8; c1 += C8NUM) {
gradient_r = _mm256_loadu_ps(gradients_ptr); gradient_r = _mm256_loadu_ps(gradients_ptr);
tmp_r1 = _mm256_mul_ps(gradient_r, gradient_r); mean_square_r = _mm256_loadu_ps(mean_square_ptr);
tmp_r2 = _mm256_loadu_ps(mean_square_ptr); avx_r1 = _mm256_sub_ps(_mm256_mul_ps(gradient_r, gradient_r), mean_square_r);
tmp_r3 = _mm256_sub_ps(tmp_r1, tmp_r2); avx_r2 = _mm256_mul_ps(avx_r1, decay_r);
tmp_r1 = _mm256_mul_ps(decay_r, tmp_r3); mean_square_r = _mm256_add_ps(mean_square_r, avx_r2);
mean_square_r = _mm256_add_ps(tmp_r2, tmp_r1);
_mm256_storeu_ps(mean_square_ptr, mean_square_r); _mm256_storeu_ps(mean_square_ptr, mean_square_r);
tmp_r1 = _mm256_set1_ps(epsilon); avx_r1 = _mm256_add_ps(_mm256_sqrt_ps(mean_square_r), epsi_r);
tmp_r2 = _mm256_add_ps(mean_square_r, tmp_r1); avx_r2 = _mm256_div_ps(_mm256_mul_ps(gradient_r, lr_r), avx_r1);
tmp_r1 = _mm256_sqrt_ps(tmp_r2);
tmp_r2 = _mm256_mul_ps(gradient_r, lr_r);
tmp_r3 = _mm256_div_ps(tmp_r2, tmp_r1);
tmp_r1 = _mm256_loadu_ps(moment_ptr);
tmp_r2 = _mm256_mul_ps(tmp_r1, moment_r);
tmp_r3 = _mm256_add_ps(tmp_r2, tmp_r3);
_mm256_storeu_ps(moment_ptr, tmp_r3);
tmp_r1 = _mm256_loadu_ps(variable_ptr); moment_r = _mm256_loadu_ps(moment_ptr);
tmp_r2 = _mm256_sub_ps(tmp_r1, tmp_r3); avx_r1 = _mm256_add_ps(_mm256_mul_ps(moment_r, momentum_r), avx_r2);
_mm256_storeu_ps(variable_ptr, tmp_r2); _mm256_storeu_ps(moment_ptr, avx_r1);
variable_r = _mm256_loadu_ps(variable_ptr);
variable_r = _mm256_sub_ps(variable_r, avx_r1);
_mm256_storeu_ps(variable_ptr, variable_r);
gradients_ptr += C8NUM; gradients_ptr += C8NUM;
mean_square_ptr += C8NUM; mean_square_ptr += C8NUM;
@ -81,54 +77,49 @@ int RMSPropUseCenterFp32(float *variable, float *mean_square, float *moment, flo
float momentum, float learning_rate, float decay, float epsilon, size_t start, size_t end) { float momentum, float learning_rate, float decay, float epsilon, size_t start, size_t end) {
size_t c1 = start; size_t c1 = start;
#ifdef ENABLE_AVX #ifdef ENABLE_AVX
size_t c8 = ((end - start) / C8NUM) * C8NUM;
float *variable_ptr = variable; float *variable_ptr = variable;
float *mean_gradients_ptr = mean_gradients; float *mean_gradients_ptr = mean_gradients;
float *mean_square_ptr = mean_square; float *mean_square_ptr = mean_square;
float *moment_ptr = moment; float *moment_ptr = moment;
const float *gradients_ptr = gradients; float *gradients_ptr = gradients;
const float decay_v = 1.0f - decay;
size_t c8 = (start - end / C8NUM) * C8NUM; __m256 decay_r = _mm256_set1_ps(1.0 - decay);
__m256 gradient_r; __m256 momentum_r = _mm256_set1_ps(momentum);
__m256 var_r1, var_r2, var_r3, var_r4, var_r5, var_r6; __m256 lr_r = _mm256_set1_ps(learning_rate);
for (; c1 < c8; c1 += C8NUM) { __m256 epsi_r = _mm256_set1_ps(epsilon);
gradient_r = _mm256_loadu_ps(gradients_ptr); __m256 grad_r, mean_grad_r, mean_square_r, moment_r, variable_r;
var_r1 = _mm256_mul_ps(gradient_r, gradient_r); __m256 avx_r1, avx_r2;
var_r2 = _mm256_loadu_ps(mean_square_ptr); // for (; c1 < start + c8; c1 += C8NUM) {
var_r1 = _mm256_sub_ps(var_r1, var_r2); grad_r = _mm256_loadu_ps(gradients_ptr);
var_r3 = _mm256_set1_ps(decay_v); // 1 - decay ... mean_square_r = _mm256_loadu_ps(mean_square_ptr);
var_r1 = _mm256_mul_ps(var_r1, var_r3); avx_r1 = _mm256_sub_ps(_mm256_mul_ps(grad_r, grad_r), mean_square_r);
var_r1 = _mm256_add_ps(var_r2, var_r1); // mean_squasre ... avx_r2 = _mm256_mul_ps(avx_r1, decay_r);
_mm256_storeu_ps(mean_square_ptr, var_r1); mean_square_r = _mm256_add_ps(mean_square_r, avx_r2);
_mm256_storeu_ps(mean_square_ptr, mean_square_r);
var_r2 = _mm256_loadu_ps(mean_gradients_ptr); mean_grad_r = _mm256_loadu_ps(mean_gradients_ptr);
var_r4 = _mm256_sub_ps(gradient_r, var_r2); avx_r1 = _mm256_mul_ps(_mm256_sub_ps(grad_r, mean_grad_r), decay_r);
var_r3 = _mm256_mul_ps(var_r4, var_r3); mean_grad_r = _mm256_add_ps(mean_grad_r, avx_r1);
var_r2 = _mm256_add_ps(var_r2, var_r3); // mean_gradients .. _mm256_storeu_ps(mean_gradients_ptr, mean_grad_r);
_mm256_storeu_ps(mean_gradients_ptr, var_r2);
var_r3 = _mm256_mul_ps(var_r2, var_r2); avx_r1 = _mm256_sub_ps(mean_square_r, _mm256_mul_ps(mean_grad_r, mean_grad_r));
var_r3 = _mm256_sub_ps(var_r1, var_r3); __m256 denom_r = _mm256_add_ps(avx_r1, epsi_r);
var_r4 = _mm256_set1_ps(epsilon); __m256 cmp_r = _mm256_cmp_ps(denom_r, _mm256_setzero_ps(), _CMP_GE_OS);
var_r3 = _mm256_add_ps(var_r3, var_r4); // denom ... __m256 gt_zero_r = _mm256_blendv_ps(_mm256_set1_ps(1.0f), denom_r, cmp_r);
var_r5 = _mm256_setzero_ps();
var_r1 = _mm256_cmp_ps(var_r3, var_r5, _CMP_GE_OS); // mask_r
var_r4 = _mm256_set1_ps(learning_rate); avx_r1 = _mm256_mul_ps(grad_r, lr_r);
var_r5 = _mm256_mul_ps(gradient_r, var_r4); avx_r2 = _mm256_div_ps(avx_r1, _mm256_sqrt_ps(gt_zero_r));
var_r4 = _mm256_sqrt_ps(var_r3); moment_r = _mm256_loadu_ps(moment_ptr);
var_r6 = _mm256_div_ps(var_r5, var_r4); // (gradients[i] * learning_rate[0]) / sqrt(denom) avx_r1 = _mm256_mul_ps(moment_r, momentum_r);
var_r4 = _mm256_loadu_ps(moment_ptr); // .... avx_r1 = _mm256_add_ps(avx_r1, avx_r2);
var_r5 = _mm256_set1_ps(momentum); moment_r = _mm256_blendv_ps(moment_r, avx_r1, cmp_r);
var_r2 = _mm256_mul_ps(var_r4, var_r5); // moment[i] * momentum[i] _mm256_storeu_ps(moment_ptr, moment_r);
var_r3 = _mm256_add_ps(var_r6, var_r2);
var_r4 = _mm256_blendv_ps(var_r4, var_r3, var_r1);
_mm256_storeu_ps(moment_ptr, var_r4);
var_r2 = _mm256_loadu_ps(variable_ptr); variable_r = _mm256_loadu_ps(variable_ptr);
var_r5 = _mm256_sub_ps(var_r2, var_r4); avx_r1 = _mm256_sub_ps(variable_r, moment_r);
var_r6 = _mm256_blendv_ps(var_r2, var_r5, var_r1); variable_r = _mm256_blendv_ps(variable_r, avx_r1, cmp_r);
_mm256_storeu_ps(variable_ptr, var_r6); _mm256_storeu_ps(variable_ptr, variable_r);
variable_ptr += C8NUM; variable_ptr += C8NUM;
mean_gradients_ptr += C8NUM; mean_gradients_ptr += C8NUM;