!18387 fix optimizer bug

Merge pull request !18387 from wangyanling/optimizer
This commit is contained in:
i-robot 2021-06-17 21:45:10 +08:00 committed by Gitee
commit 1a694111e5
2 changed files with 54 additions and 63 deletions

View File

@ -30,7 +30,7 @@ constexpr size_t kOutputSize = 2;
void ApplyAdagradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
update_slots_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "update_slots");
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
}
bool ApplyAdagradCPUKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,

View File

@ -28,39 +28,35 @@ int RMSPropUnuseCenterFp32(float *variable, float *mean_square, float *moment, f
float learning_rate, float decay, float epsilon, size_t start, size_t end) {
size_t c1 = start;
#ifdef ENABLE_AVX
size_t c8 = ((end - start) / C8NUM) * C8NUM;
float *variable_ptr = variable;
float *mean_square_ptr = mean_square;
float *gradients_ptr = gradients;
float *moment_ptr = moment;
float decay_v = 1.0 - decay;
size_t c8 = ((end - start) / C8NUM) * C8NUM;
__m256 decay_r = _mm256_set1_ps(decay_v);
__m256 moment_r = _mm256_set1_ps(momentum);
__m256 decay_r = _mm256_set1_ps(1.0 - decay);
__m256 momentum_r = _mm256_set1_ps(momentum);
__m256 lr_r = _mm256_set1_ps(learning_rate);
__m256 gradient_r, mean_square_r, tmp_r1, tmp_r2, tmp_r3;
for (; c1 < c8; c1 += C8NUM) {
__m256 epsi_r = _mm256_set1_ps(epsilon);
__m256 gradient_r, mean_square_r, moment_r, variable_r, avx_r1, avx_r2;
for (; c1 < start + c8; c1 += C8NUM) {
gradient_r = _mm256_loadu_ps(gradients_ptr);
tmp_r1 = _mm256_mul_ps(gradient_r, gradient_r);
tmp_r2 = _mm256_loadu_ps(mean_square_ptr);
tmp_r3 = _mm256_sub_ps(tmp_r1, tmp_r2);
tmp_r1 = _mm256_mul_ps(decay_r, tmp_r3);
mean_square_r = _mm256_add_ps(tmp_r2, tmp_r1);
mean_square_r = _mm256_loadu_ps(mean_square_ptr);
avx_r1 = _mm256_sub_ps(_mm256_mul_ps(gradient_r, gradient_r), mean_square_r);
avx_r2 = _mm256_mul_ps(avx_r1, decay_r);
mean_square_r = _mm256_add_ps(mean_square_r, avx_r2);
_mm256_storeu_ps(mean_square_ptr, mean_square_r);
tmp_r1 = _mm256_set1_ps(epsilon);
tmp_r2 = _mm256_add_ps(mean_square_r, tmp_r1);
tmp_r1 = _mm256_sqrt_ps(tmp_r2);
tmp_r2 = _mm256_mul_ps(gradient_r, lr_r);
tmp_r3 = _mm256_div_ps(tmp_r2, tmp_r1);
tmp_r1 = _mm256_loadu_ps(moment_ptr);
tmp_r2 = _mm256_mul_ps(tmp_r1, moment_r);
tmp_r3 = _mm256_add_ps(tmp_r2, tmp_r3);
_mm256_storeu_ps(moment_ptr, tmp_r3);
avx_r1 = _mm256_add_ps(_mm256_sqrt_ps(mean_square_r), epsi_r);
avx_r2 = _mm256_div_ps(_mm256_mul_ps(gradient_r, lr_r), avx_r1);
tmp_r1 = _mm256_loadu_ps(variable_ptr);
tmp_r2 = _mm256_sub_ps(tmp_r1, tmp_r3);
_mm256_storeu_ps(variable_ptr, tmp_r2);
moment_r = _mm256_loadu_ps(moment_ptr);
avx_r1 = _mm256_add_ps(_mm256_mul_ps(moment_r, momentum_r), avx_r2);
_mm256_storeu_ps(moment_ptr, avx_r1);
variable_r = _mm256_loadu_ps(variable_ptr);
variable_r = _mm256_sub_ps(variable_r, avx_r1);
_mm256_storeu_ps(variable_ptr, variable_r);
gradients_ptr += C8NUM;
mean_square_ptr += C8NUM;
@ -81,54 +77,49 @@ int RMSPropUseCenterFp32(float *variable, float *mean_square, float *moment, flo
float momentum, float learning_rate, float decay, float epsilon, size_t start, size_t end) {
size_t c1 = start;
#ifdef ENABLE_AVX
size_t c8 = ((end - start) / C8NUM) * C8NUM;
float *variable_ptr = variable;
float *mean_gradients_ptr = mean_gradients;
float *mean_square_ptr = mean_square;
float *moment_ptr = moment;
const float *gradients_ptr = gradients;
const float decay_v = 1.0f - decay;
float *gradients_ptr = gradients;
size_t c8 = (start - end / C8NUM) * C8NUM;
__m256 gradient_r;
__m256 var_r1, var_r2, var_r3, var_r4, var_r5, var_r6;
for (; c1 < c8; c1 += C8NUM) {
gradient_r = _mm256_loadu_ps(gradients_ptr);
var_r1 = _mm256_mul_ps(gradient_r, gradient_r);
var_r2 = _mm256_loadu_ps(mean_square_ptr); //
var_r1 = _mm256_sub_ps(var_r1, var_r2);
var_r3 = _mm256_set1_ps(decay_v); // 1 - decay ...
var_r1 = _mm256_mul_ps(var_r1, var_r3);
var_r1 = _mm256_add_ps(var_r2, var_r1); // mean_squasre ...
_mm256_storeu_ps(mean_square_ptr, var_r1);
__m256 decay_r = _mm256_set1_ps(1.0 - decay);
__m256 momentum_r = _mm256_set1_ps(momentum);
__m256 lr_r = _mm256_set1_ps(learning_rate);
__m256 epsi_r = _mm256_set1_ps(epsilon);
__m256 grad_r, mean_grad_r, mean_square_r, moment_r, variable_r;
__m256 avx_r1, avx_r2;
for (; c1 < start + c8; c1 += C8NUM) {
grad_r = _mm256_loadu_ps(gradients_ptr);
mean_square_r = _mm256_loadu_ps(mean_square_ptr);
avx_r1 = _mm256_sub_ps(_mm256_mul_ps(grad_r, grad_r), mean_square_r);
avx_r2 = _mm256_mul_ps(avx_r1, decay_r);
mean_square_r = _mm256_add_ps(mean_square_r, avx_r2);
_mm256_storeu_ps(mean_square_ptr, mean_square_r);
var_r2 = _mm256_loadu_ps(mean_gradients_ptr);
var_r4 = _mm256_sub_ps(gradient_r, var_r2);
var_r3 = _mm256_mul_ps(var_r4, var_r3);
var_r2 = _mm256_add_ps(var_r2, var_r3); // mean_gradients ..
_mm256_storeu_ps(mean_gradients_ptr, var_r2);
mean_grad_r = _mm256_loadu_ps(mean_gradients_ptr);
avx_r1 = _mm256_mul_ps(_mm256_sub_ps(grad_r, mean_grad_r), decay_r);
mean_grad_r = _mm256_add_ps(mean_grad_r, avx_r1);
_mm256_storeu_ps(mean_gradients_ptr, mean_grad_r);
var_r3 = _mm256_mul_ps(var_r2, var_r2);
var_r3 = _mm256_sub_ps(var_r1, var_r3);
var_r4 = _mm256_set1_ps(epsilon);
var_r3 = _mm256_add_ps(var_r3, var_r4); // denom ...
var_r5 = _mm256_setzero_ps();
var_r1 = _mm256_cmp_ps(var_r3, var_r5, _CMP_GE_OS); // mask_r
avx_r1 = _mm256_sub_ps(mean_square_r, _mm256_mul_ps(mean_grad_r, mean_grad_r));
__m256 denom_r = _mm256_add_ps(avx_r1, epsi_r);
__m256 cmp_r = _mm256_cmp_ps(denom_r, _mm256_setzero_ps(), _CMP_GE_OS);
__m256 gt_zero_r = _mm256_blendv_ps(_mm256_set1_ps(1.0f), denom_r, cmp_r);
var_r4 = _mm256_set1_ps(learning_rate);
var_r5 = _mm256_mul_ps(gradient_r, var_r4);
var_r4 = _mm256_sqrt_ps(var_r3);
var_r6 = _mm256_div_ps(var_r5, var_r4); // (gradients[i] * learning_rate[0]) / sqrt(denom)
var_r4 = _mm256_loadu_ps(moment_ptr); // ....
var_r5 = _mm256_set1_ps(momentum);
var_r2 = _mm256_mul_ps(var_r4, var_r5); // moment[i] * momentum[i]
var_r3 = _mm256_add_ps(var_r6, var_r2);
var_r4 = _mm256_blendv_ps(var_r4, var_r3, var_r1);
_mm256_storeu_ps(moment_ptr, var_r4);
avx_r1 = _mm256_mul_ps(grad_r, lr_r);
avx_r2 = _mm256_div_ps(avx_r1, _mm256_sqrt_ps(gt_zero_r));
moment_r = _mm256_loadu_ps(moment_ptr);
avx_r1 = _mm256_mul_ps(moment_r, momentum_r);
avx_r1 = _mm256_add_ps(avx_r1, avx_r2);
moment_r = _mm256_blendv_ps(moment_r, avx_r1, cmp_r);
_mm256_storeu_ps(moment_ptr, moment_r);
var_r2 = _mm256_loadu_ps(variable_ptr);
var_r5 = _mm256_sub_ps(var_r2, var_r4);
var_r6 = _mm256_blendv_ps(var_r2, var_r5, var_r1);
_mm256_storeu_ps(variable_ptr, var_r6);
variable_r = _mm256_loadu_ps(variable_ptr);
avx_r1 = _mm256_sub_ps(variable_r, moment_r);
variable_r = _mm256_blendv_ps(variable_r, avx_r1, cmp_r);
_mm256_storeu_ps(variable_ptr, variable_r);
variable_ptr += C8NUM;
mean_gradients_ptr += C8NUM;