forked from mindspore-Ecosystem/mindspore
!18387 fix optimizer bug
Merge pull request !18387 from wangyanling/optimizer
This commit is contained in:
commit
1a694111e5
|
@ -30,7 +30,7 @@ constexpr size_t kOutputSize = 2;
|
||||||
void ApplyAdagradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
void ApplyAdagradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
update_slots_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "update_slots");
|
update_slots_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "update_slots");
|
||||||
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
|
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ApplyAdagradCPUKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
bool ApplyAdagradCPUKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||||
|
|
|
@ -28,39 +28,35 @@ int RMSPropUnuseCenterFp32(float *variable, float *mean_square, float *moment, f
|
||||||
float learning_rate, float decay, float epsilon, size_t start, size_t end) {
|
float learning_rate, float decay, float epsilon, size_t start, size_t end) {
|
||||||
size_t c1 = start;
|
size_t c1 = start;
|
||||||
#ifdef ENABLE_AVX
|
#ifdef ENABLE_AVX
|
||||||
|
size_t c8 = ((end - start) / C8NUM) * C8NUM;
|
||||||
float *variable_ptr = variable;
|
float *variable_ptr = variable;
|
||||||
float *mean_square_ptr = mean_square;
|
float *mean_square_ptr = mean_square;
|
||||||
float *gradients_ptr = gradients;
|
float *gradients_ptr = gradients;
|
||||||
float *moment_ptr = moment;
|
float *moment_ptr = moment;
|
||||||
float decay_v = 1.0 - decay;
|
|
||||||
|
|
||||||
size_t c8 = ((end - start) / C8NUM) * C8NUM;
|
__m256 decay_r = _mm256_set1_ps(1.0 - decay);
|
||||||
__m256 decay_r = _mm256_set1_ps(decay_v);
|
__m256 momentum_r = _mm256_set1_ps(momentum);
|
||||||
__m256 moment_r = _mm256_set1_ps(momentum);
|
|
||||||
__m256 lr_r = _mm256_set1_ps(learning_rate);
|
__m256 lr_r = _mm256_set1_ps(learning_rate);
|
||||||
__m256 gradient_r, mean_square_r, tmp_r1, tmp_r2, tmp_r3;
|
__m256 epsi_r = _mm256_set1_ps(epsilon);
|
||||||
for (; c1 < c8; c1 += C8NUM) {
|
__m256 gradient_r, mean_square_r, moment_r, variable_r, avx_r1, avx_r2;
|
||||||
|
for (; c1 < start + c8; c1 += C8NUM) {
|
||||||
gradient_r = _mm256_loadu_ps(gradients_ptr);
|
gradient_r = _mm256_loadu_ps(gradients_ptr);
|
||||||
tmp_r1 = _mm256_mul_ps(gradient_r, gradient_r);
|
mean_square_r = _mm256_loadu_ps(mean_square_ptr);
|
||||||
tmp_r2 = _mm256_loadu_ps(mean_square_ptr);
|
avx_r1 = _mm256_sub_ps(_mm256_mul_ps(gradient_r, gradient_r), mean_square_r);
|
||||||
tmp_r3 = _mm256_sub_ps(tmp_r1, tmp_r2);
|
avx_r2 = _mm256_mul_ps(avx_r1, decay_r);
|
||||||
tmp_r1 = _mm256_mul_ps(decay_r, tmp_r3);
|
mean_square_r = _mm256_add_ps(mean_square_r, avx_r2);
|
||||||
mean_square_r = _mm256_add_ps(tmp_r2, tmp_r1);
|
|
||||||
_mm256_storeu_ps(mean_square_ptr, mean_square_r);
|
_mm256_storeu_ps(mean_square_ptr, mean_square_r);
|
||||||
|
|
||||||
tmp_r1 = _mm256_set1_ps(epsilon);
|
avx_r1 = _mm256_add_ps(_mm256_sqrt_ps(mean_square_r), epsi_r);
|
||||||
tmp_r2 = _mm256_add_ps(mean_square_r, tmp_r1);
|
avx_r2 = _mm256_div_ps(_mm256_mul_ps(gradient_r, lr_r), avx_r1);
|
||||||
tmp_r1 = _mm256_sqrt_ps(tmp_r2);
|
|
||||||
tmp_r2 = _mm256_mul_ps(gradient_r, lr_r);
|
|
||||||
tmp_r3 = _mm256_div_ps(tmp_r2, tmp_r1);
|
|
||||||
tmp_r1 = _mm256_loadu_ps(moment_ptr);
|
|
||||||
tmp_r2 = _mm256_mul_ps(tmp_r1, moment_r);
|
|
||||||
tmp_r3 = _mm256_add_ps(tmp_r2, tmp_r3);
|
|
||||||
_mm256_storeu_ps(moment_ptr, tmp_r3);
|
|
||||||
|
|
||||||
tmp_r1 = _mm256_loadu_ps(variable_ptr);
|
moment_r = _mm256_loadu_ps(moment_ptr);
|
||||||
tmp_r2 = _mm256_sub_ps(tmp_r1, tmp_r3);
|
avx_r1 = _mm256_add_ps(_mm256_mul_ps(moment_r, momentum_r), avx_r2);
|
||||||
_mm256_storeu_ps(variable_ptr, tmp_r2);
|
_mm256_storeu_ps(moment_ptr, avx_r1);
|
||||||
|
|
||||||
|
variable_r = _mm256_loadu_ps(variable_ptr);
|
||||||
|
variable_r = _mm256_sub_ps(variable_r, avx_r1);
|
||||||
|
_mm256_storeu_ps(variable_ptr, variable_r);
|
||||||
|
|
||||||
gradients_ptr += C8NUM;
|
gradients_ptr += C8NUM;
|
||||||
mean_square_ptr += C8NUM;
|
mean_square_ptr += C8NUM;
|
||||||
|
@ -81,54 +77,49 @@ int RMSPropUseCenterFp32(float *variable, float *mean_square, float *moment, flo
|
||||||
float momentum, float learning_rate, float decay, float epsilon, size_t start, size_t end) {
|
float momentum, float learning_rate, float decay, float epsilon, size_t start, size_t end) {
|
||||||
size_t c1 = start;
|
size_t c1 = start;
|
||||||
#ifdef ENABLE_AVX
|
#ifdef ENABLE_AVX
|
||||||
|
size_t c8 = ((end - start) / C8NUM) * C8NUM;
|
||||||
float *variable_ptr = variable;
|
float *variable_ptr = variable;
|
||||||
float *mean_gradients_ptr = mean_gradients;
|
float *mean_gradients_ptr = mean_gradients;
|
||||||
float *mean_square_ptr = mean_square;
|
float *mean_square_ptr = mean_square;
|
||||||
float *moment_ptr = moment;
|
float *moment_ptr = moment;
|
||||||
const float *gradients_ptr = gradients;
|
float *gradients_ptr = gradients;
|
||||||
const float decay_v = 1.0f - decay;
|
|
||||||
|
|
||||||
size_t c8 = (start - end / C8NUM) * C8NUM;
|
__m256 decay_r = _mm256_set1_ps(1.0 - decay);
|
||||||
__m256 gradient_r;
|
__m256 momentum_r = _mm256_set1_ps(momentum);
|
||||||
__m256 var_r1, var_r2, var_r3, var_r4, var_r5, var_r6;
|
__m256 lr_r = _mm256_set1_ps(learning_rate);
|
||||||
for (; c1 < c8; c1 += C8NUM) {
|
__m256 epsi_r = _mm256_set1_ps(epsilon);
|
||||||
gradient_r = _mm256_loadu_ps(gradients_ptr);
|
__m256 grad_r, mean_grad_r, mean_square_r, moment_r, variable_r;
|
||||||
var_r1 = _mm256_mul_ps(gradient_r, gradient_r);
|
__m256 avx_r1, avx_r2;
|
||||||
var_r2 = _mm256_loadu_ps(mean_square_ptr); //
|
for (; c1 < start + c8; c1 += C8NUM) {
|
||||||
var_r1 = _mm256_sub_ps(var_r1, var_r2);
|
grad_r = _mm256_loadu_ps(gradients_ptr);
|
||||||
var_r3 = _mm256_set1_ps(decay_v); // 1 - decay ...
|
mean_square_r = _mm256_loadu_ps(mean_square_ptr);
|
||||||
var_r1 = _mm256_mul_ps(var_r1, var_r3);
|
avx_r1 = _mm256_sub_ps(_mm256_mul_ps(grad_r, grad_r), mean_square_r);
|
||||||
var_r1 = _mm256_add_ps(var_r2, var_r1); // mean_squasre ...
|
avx_r2 = _mm256_mul_ps(avx_r1, decay_r);
|
||||||
_mm256_storeu_ps(mean_square_ptr, var_r1);
|
mean_square_r = _mm256_add_ps(mean_square_r, avx_r2);
|
||||||
|
_mm256_storeu_ps(mean_square_ptr, mean_square_r);
|
||||||
|
|
||||||
var_r2 = _mm256_loadu_ps(mean_gradients_ptr);
|
mean_grad_r = _mm256_loadu_ps(mean_gradients_ptr);
|
||||||
var_r4 = _mm256_sub_ps(gradient_r, var_r2);
|
avx_r1 = _mm256_mul_ps(_mm256_sub_ps(grad_r, mean_grad_r), decay_r);
|
||||||
var_r3 = _mm256_mul_ps(var_r4, var_r3);
|
mean_grad_r = _mm256_add_ps(mean_grad_r, avx_r1);
|
||||||
var_r2 = _mm256_add_ps(var_r2, var_r3); // mean_gradients ..
|
_mm256_storeu_ps(mean_gradients_ptr, mean_grad_r);
|
||||||
_mm256_storeu_ps(mean_gradients_ptr, var_r2);
|
|
||||||
|
|
||||||
var_r3 = _mm256_mul_ps(var_r2, var_r2);
|
avx_r1 = _mm256_sub_ps(mean_square_r, _mm256_mul_ps(mean_grad_r, mean_grad_r));
|
||||||
var_r3 = _mm256_sub_ps(var_r1, var_r3);
|
__m256 denom_r = _mm256_add_ps(avx_r1, epsi_r);
|
||||||
var_r4 = _mm256_set1_ps(epsilon);
|
__m256 cmp_r = _mm256_cmp_ps(denom_r, _mm256_setzero_ps(), _CMP_GE_OS);
|
||||||
var_r3 = _mm256_add_ps(var_r3, var_r4); // denom ...
|
__m256 gt_zero_r = _mm256_blendv_ps(_mm256_set1_ps(1.0f), denom_r, cmp_r);
|
||||||
var_r5 = _mm256_setzero_ps();
|
|
||||||
var_r1 = _mm256_cmp_ps(var_r3, var_r5, _CMP_GE_OS); // mask_r
|
|
||||||
|
|
||||||
var_r4 = _mm256_set1_ps(learning_rate);
|
avx_r1 = _mm256_mul_ps(grad_r, lr_r);
|
||||||
var_r5 = _mm256_mul_ps(gradient_r, var_r4);
|
avx_r2 = _mm256_div_ps(avx_r1, _mm256_sqrt_ps(gt_zero_r));
|
||||||
var_r4 = _mm256_sqrt_ps(var_r3);
|
moment_r = _mm256_loadu_ps(moment_ptr);
|
||||||
var_r6 = _mm256_div_ps(var_r5, var_r4); // (gradients[i] * learning_rate[0]) / sqrt(denom)
|
avx_r1 = _mm256_mul_ps(moment_r, momentum_r);
|
||||||
var_r4 = _mm256_loadu_ps(moment_ptr); // ....
|
avx_r1 = _mm256_add_ps(avx_r1, avx_r2);
|
||||||
var_r5 = _mm256_set1_ps(momentum);
|
moment_r = _mm256_blendv_ps(moment_r, avx_r1, cmp_r);
|
||||||
var_r2 = _mm256_mul_ps(var_r4, var_r5); // moment[i] * momentum[i]
|
_mm256_storeu_ps(moment_ptr, moment_r);
|
||||||
var_r3 = _mm256_add_ps(var_r6, var_r2);
|
|
||||||
var_r4 = _mm256_blendv_ps(var_r4, var_r3, var_r1);
|
|
||||||
_mm256_storeu_ps(moment_ptr, var_r4);
|
|
||||||
|
|
||||||
var_r2 = _mm256_loadu_ps(variable_ptr);
|
variable_r = _mm256_loadu_ps(variable_ptr);
|
||||||
var_r5 = _mm256_sub_ps(var_r2, var_r4);
|
avx_r1 = _mm256_sub_ps(variable_r, moment_r);
|
||||||
var_r6 = _mm256_blendv_ps(var_r2, var_r5, var_r1);
|
variable_r = _mm256_blendv_ps(variable_r, avx_r1, cmp_r);
|
||||||
_mm256_storeu_ps(variable_ptr, var_r6);
|
_mm256_storeu_ps(variable_ptr, variable_r);
|
||||||
|
|
||||||
variable_ptr += C8NUM;
|
variable_ptr += C8NUM;
|
||||||
mean_gradients_ptr += C8NUM;
|
mean_gradients_ptr += C8NUM;
|
||||||
|
|
Loading…
Reference in New Issue