!24033 fix cpu rmsprop

Merge pull request !24033 from 范吉斌/fix_bug
This commit is contained in:
i-robot 2021-09-26 06:55:01 +00:00 committed by Gitee
commit 93ea5f5992
2 changed files with 4 additions and 4 deletions

View File

@ -98,7 +98,7 @@ template <typename T>
bool RMSPropCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &) {
if (!use_center_) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kCenteredRMSPropInputsNum, kernel_name_);
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kRMSPropInputsNum, kernel_name_);
float *variable = reinterpret_cast<float *>(inputs[0]->addr);
float *mean_square = reinterpret_cast<float *>(inputs[1]->addr);
float *moment = reinterpret_cast<float *>(inputs[2]->addr);
@ -109,7 +109,7 @@ bool RMSPropCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
MS_LOG(INFO) << "RMSPropCPUKernel lens:" << lens << " size_:" << size_;
LaunchRMSPropUnuseCenter(variable, mean_square, moment, gradients, learning_rate);
} else {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kRMSPropInputsNum, kernel_name_);
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kCenteredRMSPropInputsNum, kernel_name_);
T *variable = reinterpret_cast<float *>(inputs[0]->addr);
T *mean_gradients = reinterpret_cast<float *>(inputs[1]->addr);
T *mean_square = reinterpret_cast<float *>(inputs[2]->addr);

View File

@ -28,8 +28,8 @@ constexpr size_t kUnsortedSegmentOutputsNum = 1;
void UnsortedSegmentSumCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
segment_ids_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 1);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
segment_ids_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 1);
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto segment_ids_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);