forked from mindspore-Ecosystem/mindspore
Fix: ApplyAdagrad CodeDEX.
This commit is contained in:
parent
823ba8d71f
commit
8ba9987460
|
@ -84,6 +84,7 @@ void ApplyAdagradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs)
|
|||
|
||||
if (batch_size == 0) {
|
||||
MS_LOG(EXCEPTION) << "Error occur in launch kernel";
|
||||
return;
|
||||
}
|
||||
while (start < length) {
|
||||
size_t end = (start + batch_size) > length ? length : (start + batch_size);
|
||||
|
@ -98,7 +99,8 @@ void ApplyAdagradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs)
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void ApplyAdagradCPUKernel::LaunchApplyAdagrad(T var, T accum, T lr, T gradient, size_t start, size_t end) {
|
||||
void ApplyAdagradCPUKernel::LaunchApplyAdagrad(T const var, T const accum, const T lr, const T gradient, size_t start,
|
||||
size_t end) {
|
||||
// DataType can only be float32 or float16, so eps will not be zero.
|
||||
using DataType = typename std::iterator_traits<T>::value_type;
|
||||
const DataType one = DataType(1);
|
||||
|
|
|
@ -38,7 +38,7 @@ class ApplyAdagradCPUKernel : public CPUKernel {
|
|||
template <typename T>
|
||||
void LaunchKernel(const std::vector<AddressPtr> &inputs);
|
||||
template <typename T>
|
||||
void LaunchApplyAdagrad(T var, T accum, T lr, T gradient, size_t start, size_t end);
|
||||
void LaunchApplyAdagrad(T const var, T const accum, const T lr, const T gradient, size_t start, size_t end);
|
||||
bool update_slots_{true};
|
||||
TypeId dtype_{kTypeUnknown};
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue