fixes SGD
This commit is contained in:
parent
9523d28536
commit
cd749b90f1
|
@ -47,7 +47,7 @@ void SGDCPUKernel<T>::CheckParam(const std::vector<AddressPtr> &inputs, const st
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
bool SGDCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> & /*workspace*/,
|
||||
bool SGDCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
CheckParam(inputs, outputs);
|
||||
|
||||
|
@ -57,7 +57,8 @@ bool SGDCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std::v
|
|||
auto accum = reinterpret_cast<T *>(inputs[3]->addr);
|
||||
auto momentum = reinterpret_cast<T *>(inputs[4]->addr);
|
||||
auto stat = reinterpret_cast<T *>(inputs[5]->addr);
|
||||
size_t elem_num = inputs[0]->size / sizeof(float);
|
||||
auto output_param = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
size_t elem_num = inputs[0]->size / sizeof(T);
|
||||
|
||||
auto task = [&](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
|
@ -79,6 +80,7 @@ bool SGDCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std::v
|
|||
}
|
||||
}
|
||||
param[i] -= lr[0] * grad_new;
|
||||
output_param[i] = param[i];
|
||||
}
|
||||
};
|
||||
CPUKernelUtils::ParallelFor(task, elem_num);
|
||||
|
|
Loading…
Reference in New Issue