diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/apply_adagrad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/apply_adagrad_cpu_kernel.cc index f0b38162bb1..238b5c5e9a3 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/apply_adagrad_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/apply_adagrad_cpu_kernel.cc @@ -30,7 +30,7 @@ constexpr size_t kOutputSize = 2; void ApplyAdagradCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); update_slots_ = AnfAlgo::GetNodeAttr(kernel_node, "update_slots"); - dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); + dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); } bool ApplyAdagradCPUKernel::Launch(const std::vector &inputs, const std::vector &, diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/rmsprop_fp32.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/rmsprop_fp32.c index 362709b0b26..e5fbfa1dba1 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/rmsprop_fp32.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/rmsprop_fp32.c @@ -28,39 +28,35 @@ int RMSPropUnuseCenterFp32(float *variable, float *mean_square, float *moment, f float learning_rate, float decay, float epsilon, size_t start, size_t end) { size_t c1 = start; #ifdef ENABLE_AVX + size_t c8 = ((end - start) / C8NUM) * C8NUM; float *variable_ptr = variable; float *mean_square_ptr = mean_square; float *gradients_ptr = gradients; float *moment_ptr = moment; - float decay_v = 1.0 - decay; - size_t c8 = ((end - start) / C8NUM) * C8NUM; - __m256 decay_r = _mm256_set1_ps(decay_v); - __m256 moment_r = _mm256_set1_ps(momentum); + __m256 decay_r = _mm256_set1_ps(1.0 - decay); + __m256 momentum_r = _mm256_set1_ps(momentum); __m256 lr_r = _mm256_set1_ps(learning_rate); - __m256 gradient_r, mean_square_r, tmp_r1, tmp_r2, tmp_r3; - for (; c1 < c8; c1 += C8NUM) { + __m256 epsi_r = _mm256_set1_ps(epsilon); + __m256 gradient_r, mean_square_r, moment_r, variable_r, avx_r1, avx_r2; + for (; c1 < start + c8; c1 += C8NUM) { gradient_r = _mm256_loadu_ps(gradients_ptr); - tmp_r1 = _mm256_mul_ps(gradient_r, gradient_r); - tmp_r2 = _mm256_loadu_ps(mean_square_ptr); - tmp_r3 = _mm256_sub_ps(tmp_r1, tmp_r2); - tmp_r1 = _mm256_mul_ps(decay_r, tmp_r3); - mean_square_r = _mm256_add_ps(tmp_r2, tmp_r1); + mean_square_r = _mm256_loadu_ps(mean_square_ptr); + avx_r1 = _mm256_sub_ps(_mm256_mul_ps(gradient_r, gradient_r), mean_square_r); + avx_r2 = _mm256_mul_ps(avx_r1, decay_r); + mean_square_r = _mm256_add_ps(mean_square_r, avx_r2); _mm256_storeu_ps(mean_square_ptr, mean_square_r); - tmp_r1 = _mm256_set1_ps(epsilon); - tmp_r2 = _mm256_add_ps(mean_square_r, tmp_r1); - tmp_r1 = _mm256_sqrt_ps(tmp_r2); - tmp_r2 = _mm256_mul_ps(gradient_r, lr_r); - tmp_r3 = _mm256_div_ps(tmp_r2, tmp_r1); - tmp_r1 = _mm256_loadu_ps(moment_ptr); - tmp_r2 = _mm256_mul_ps(tmp_r1, moment_r); - tmp_r3 = _mm256_add_ps(tmp_r2, tmp_r3); - _mm256_storeu_ps(moment_ptr, tmp_r3); + avx_r1 = _mm256_add_ps(_mm256_sqrt_ps(mean_square_r), epsi_r); + avx_r2 = _mm256_div_ps(_mm256_mul_ps(gradient_r, lr_r), avx_r1); - tmp_r1 = _mm256_loadu_ps(variable_ptr); - tmp_r2 = _mm256_sub_ps(tmp_r1, tmp_r3); - _mm256_storeu_ps(variable_ptr, tmp_r2); + moment_r = _mm256_loadu_ps(moment_ptr); + avx_r1 = _mm256_add_ps(_mm256_mul_ps(moment_r, momentum_r), avx_r2); + _mm256_storeu_ps(moment_ptr, avx_r1); + + variable_r = _mm256_loadu_ps(variable_ptr); + variable_r = _mm256_sub_ps(variable_r, avx_r1); + _mm256_storeu_ps(variable_ptr, variable_r); gradients_ptr += C8NUM; mean_square_ptr += C8NUM; @@ -81,54 +77,49 @@ int RMSPropUseCenterFp32(float *variable, float *mean_square, float *moment, flo float momentum, float learning_rate, float decay, float epsilon, size_t start, size_t end) { size_t c1 = start; #ifdef ENABLE_AVX + size_t c8 = ((end - start) / C8NUM) * C8NUM; float *variable_ptr = variable; float *mean_gradients_ptr = mean_gradients; float *mean_square_ptr = mean_square; float *moment_ptr = moment; - const float *gradients_ptr = gradients; - const float decay_v = 1.0f - decay; + float *gradients_ptr = gradients; - size_t c8 = (start - end / C8NUM) * C8NUM; - __m256 gradient_r; - __m256 var_r1, var_r2, var_r3, var_r4, var_r5, var_r6; - for (; c1 < c8; c1 += C8NUM) { - gradient_r = _mm256_loadu_ps(gradients_ptr); - var_r1 = _mm256_mul_ps(gradient_r, gradient_r); - var_r2 = _mm256_loadu_ps(mean_square_ptr); // - var_r1 = _mm256_sub_ps(var_r1, var_r2); - var_r3 = _mm256_set1_ps(decay_v); // 1 - decay ... - var_r1 = _mm256_mul_ps(var_r1, var_r3); - var_r1 = _mm256_add_ps(var_r2, var_r1); // mean_squasre ... - _mm256_storeu_ps(mean_square_ptr, var_r1); + __m256 decay_r = _mm256_set1_ps(1.0 - decay); + __m256 momentum_r = _mm256_set1_ps(momentum); + __m256 lr_r = _mm256_set1_ps(learning_rate); + __m256 epsi_r = _mm256_set1_ps(epsilon); + __m256 grad_r, mean_grad_r, mean_square_r, moment_r, variable_r; + __m256 avx_r1, avx_r2; + for (; c1 < start + c8; c1 += C8NUM) { + grad_r = _mm256_loadu_ps(gradients_ptr); + mean_square_r = _mm256_loadu_ps(mean_square_ptr); + avx_r1 = _mm256_sub_ps(_mm256_mul_ps(grad_r, grad_r), mean_square_r); + avx_r2 = _mm256_mul_ps(avx_r1, decay_r); + mean_square_r = _mm256_add_ps(mean_square_r, avx_r2); + _mm256_storeu_ps(mean_square_ptr, mean_square_r); - var_r2 = _mm256_loadu_ps(mean_gradients_ptr); - var_r4 = _mm256_sub_ps(gradient_r, var_r2); - var_r3 = _mm256_mul_ps(var_r4, var_r3); - var_r2 = _mm256_add_ps(var_r2, var_r3); // mean_gradients .. - _mm256_storeu_ps(mean_gradients_ptr, var_r2); + mean_grad_r = _mm256_loadu_ps(mean_gradients_ptr); + avx_r1 = _mm256_mul_ps(_mm256_sub_ps(grad_r, mean_grad_r), decay_r); + mean_grad_r = _mm256_add_ps(mean_grad_r, avx_r1); + _mm256_storeu_ps(mean_gradients_ptr, mean_grad_r); - var_r3 = _mm256_mul_ps(var_r2, var_r2); - var_r3 = _mm256_sub_ps(var_r1, var_r3); - var_r4 = _mm256_set1_ps(epsilon); - var_r3 = _mm256_add_ps(var_r3, var_r4); // denom ... - var_r5 = _mm256_setzero_ps(); - var_r1 = _mm256_cmp_ps(var_r3, var_r5, _CMP_GE_OS); // mask_r + avx_r1 = _mm256_sub_ps(mean_square_r, _mm256_mul_ps(mean_grad_r, mean_grad_r)); + __m256 denom_r = _mm256_add_ps(avx_r1, epsi_r); + __m256 cmp_r = _mm256_cmp_ps(denom_r, _mm256_setzero_ps(), _CMP_GE_OS); + __m256 gt_zero_r = _mm256_blendv_ps(_mm256_set1_ps(1.0f), denom_r, cmp_r); - var_r4 = _mm256_set1_ps(learning_rate); - var_r5 = _mm256_mul_ps(gradient_r, var_r4); - var_r4 = _mm256_sqrt_ps(var_r3); - var_r6 = _mm256_div_ps(var_r5, var_r4); // (gradients[i] * learning_rate[0]) / sqrt(denom) - var_r4 = _mm256_loadu_ps(moment_ptr); // .... - var_r5 = _mm256_set1_ps(momentum); - var_r2 = _mm256_mul_ps(var_r4, var_r5); // moment[i] * momentum[i] - var_r3 = _mm256_add_ps(var_r6, var_r2); - var_r4 = _mm256_blendv_ps(var_r4, var_r3, var_r1); - _mm256_storeu_ps(moment_ptr, var_r4); + avx_r1 = _mm256_mul_ps(grad_r, lr_r); + avx_r2 = _mm256_div_ps(avx_r1, _mm256_sqrt_ps(gt_zero_r)); + moment_r = _mm256_loadu_ps(moment_ptr); + avx_r1 = _mm256_mul_ps(moment_r, momentum_r); + avx_r1 = _mm256_add_ps(avx_r1, avx_r2); + moment_r = _mm256_blendv_ps(moment_r, avx_r1, cmp_r); + _mm256_storeu_ps(moment_ptr, moment_r); - var_r2 = _mm256_loadu_ps(variable_ptr); - var_r5 = _mm256_sub_ps(var_r2, var_r4); - var_r6 = _mm256_blendv_ps(var_r2, var_r5, var_r1); - _mm256_storeu_ps(variable_ptr, var_r6); + variable_r = _mm256_loadu_ps(variable_ptr); + avx_r1 = _mm256_sub_ps(variable_r, moment_r); + variable_r = _mm256_blendv_ps(variable_r, avx_r1, cmp_r); + _mm256_storeu_ps(variable_ptr, variable_r); variable_ptr += C8NUM; mean_gradients_ptr += C8NUM;