fixes SGD

This commit is contained in:
huangbo77 2021-07-14 16:26:19 +08:00
parent 9523d28536
commit cd749b90f1
1 changed files with 4 additions and 2 deletions

View File

@ -47,7 +47,7 @@ void SGDCPUKernel<T>::CheckParam(const std::vector<AddressPtr> &inputs, const st
}
template <typename T>
bool SGDCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> & /*workspace*/,
bool SGDCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
CheckParam(inputs, outputs);
@ -57,7 +57,8 @@ bool SGDCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std::v
auto accum = reinterpret_cast<T *>(inputs[3]->addr);
auto momentum = reinterpret_cast<T *>(inputs[4]->addr);
auto stat = reinterpret_cast<T *>(inputs[5]->addr);
size_t elem_num = inputs[0]->size / sizeof(float);
auto output_param = reinterpret_cast<T *>(outputs[0]->addr);
size_t elem_num = inputs[0]->size / sizeof(T);
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
@ -79,6 +80,7 @@ bool SGDCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std::v
}
}
param[i] -= lr[0] * grad_new;
output_param[i] = param[i];
}
};
CPUKernelUtils::ParallelFor(task, elem_num);