!30403 update cpu adafactor code annotation
Merge pull request !30403 from kisnwang/r1.6
This commit is contained in:
commit
488dc38df0
|
@@ -109,7 +109,7 @@ void FusedAdaFactorCPUKernel::FactorUpdate(T *update, const std::vector<AddressP
|
|||
size_t last_row_col_size = last_row_dim_size_ * last_col_dim_size_;
|
||||
size_t row_dim_size = last_row_dim_size_;
|
||||
size_t col_dim_size = last_col_dim_size_;
|
||||
// exp_avg_sq_row = exp_avg_sq_row * beta2t + reduce_mean(update, -1) * one_minus_beta2t;
|
||||
// step 1: exp_avg_sq_row = exp_avg_sq_row * beta2t + reduce_mean(update, -1) * one_minus_beta2t;
|
||||
task = [&](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; ++i) {
|
||||
float row_reduce = 0;
|
||||
|
@@ -124,7 +124,7 @@ void FusedAdaFactorCPUKernel::FactorUpdate(T *update, const std::vector<AddressP
|
|||
};
|
||||
CPUKernelUtils::ParallelFor(task, exp_avg_sq_row_elem_num, kBatchSize);
|
||||
|
||||
// r_factor = sqrt(exp_avg_sq_row / reduce_mean(exp_avg_sq_row, -1))
|
||||
// step 2: r_factor = sqrt(exp_avg_sq_row / reduce_mean(exp_avg_sq_row, -1))
|
||||
task = [&](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; ++i) {
|
||||
float col_reduce = 0;
|
||||
|
@@ -142,8 +142,8 @@ void FusedAdaFactorCPUKernel::FactorUpdate(T *update, const std::vector<AddressP
|
|||
};
|
||||
CPUKernelUtils::ParallelFor(task, exp_avg_sq_row_elem_num / col_dim_size, kBatchSize);
|
||||
|
||||
// exp_avg_sq_col = exp_avg_sq_col * beta2t + reduce_mean(update, -2) * one_minus_beta2t;
|
||||
// c_factor = sqrt(exp_avg_sq_col);
|
||||
// step 3: exp_avg_sq_col = exp_avg_sq_col * beta2t + reduce_mean(update, -2) * one_minus_beta2t;
|
||||
// step 4: c_factor = sqrt(exp_avg_sq_col);
|
||||
task = [&](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; ++i) {
|
||||
float row_reduce = 0;
|
||||
|
@@ -160,7 +160,7 @@ void FusedAdaFactorCPUKernel::FactorUpdate(T *update, const std::vector<AddressP
|
|||
};
|
||||
CPUKernelUtils::ParallelFor(task, exp_avg_sq_col_elem_num, kBatchSize);
|
||||
|
||||
// update = grad / (r_factor * c_factor);
|
||||
// step 5: update = grad / (r_factor * c_factor);
|
||||
task = [&](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; ++i) {
|
||||
size_t row_i = i % row_dim_size;
|
||||
|
|
Loading…
Reference in New Issue