!30403 update cpu adafactor code annotation

Merge pull request !30403 from kisnwang/r1.6
This commit is contained in:
i-robot 2022-02-23 01:25:23 +00:00 committed by Gitee
commit 488dc38df0
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 5 additions and 5 deletions

View File

@ -109,7 +109,7 @@ void FusedAdaFactorCPUKernel::FactorUpdate(T *update, const std::vector<AddressP
size_t last_row_col_size = last_row_dim_size_ * last_col_dim_size_;
size_t row_dim_size = last_row_dim_size_;
size_t col_dim_size = last_col_dim_size_;
// exp_avg_sq_row = exp_avg_sq_row * beta2t + reduce_mean(update, -1) * one_minus_beta2t;
// step 1: exp_avg_sq_row = exp_avg_sq_row * beta2t + reduce_mean(update, -1) * one_minus_beta2t;
task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; ++i) {
float row_reduce = 0;
@ -124,7 +124,7 @@ void FusedAdaFactorCPUKernel::FactorUpdate(T *update, const std::vector<AddressP
};
CPUKernelUtils::ParallelFor(task, exp_avg_sq_row_elem_num, kBatchSize);
// r_factor = sqrt(exp_avg_sq_row / reduce_mean(exp_avg_sq_row, -1))
// step 2: r_factor = sqrt(exp_avg_sq_row / reduce_mean(exp_avg_sq_row, -1))
task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; ++i) {
float col_reduce = 0;
@ -142,8 +142,8 @@ void FusedAdaFactorCPUKernel::FactorUpdate(T *update, const std::vector<AddressP
};
CPUKernelUtils::ParallelFor(task, exp_avg_sq_row_elem_num / col_dim_size, kBatchSize);
// exp_avg_sq_col = exp_avg_sq_col * beta2t + reduce_mean(update, -2) * one_minus_beta2t;
// c_factor = sqrt(exp_avg_sq_col);
// step 3: exp_avg_sq_col = exp_avg_sq_col * beta2t + reduce_mean(update, -2) * one_minus_beta2t;
// step 4: c_factor = sqrt(exp_avg_sq_col);
task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; ++i) {
float row_reduce = 0;
@ -160,7 +160,7 @@ void FusedAdaFactorCPUKernel::FactorUpdate(T *update, const std::vector<AddressP
};
CPUKernelUtils::ParallelFor(task, exp_avg_sq_col_elem_num, kBatchSize);
// update = grad / (r_factor * c_factor);
// step 5: update = grad / (r_factor * c_factor);
task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; ++i) {
size_t row_i = i % row_dim_size;