forked from mindspore-Ecosystem/mindspore
commit
93ea5f5992
|
@ -98,7 +98,7 @@ template <typename T>
|
|||
bool RMSPropCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &) {
|
||||
if (!use_center_) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kCenteredRMSPropInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kRMSPropInputsNum, kernel_name_);
|
||||
float *variable = reinterpret_cast<float *>(inputs[0]->addr);
|
||||
float *mean_square = reinterpret_cast<float *>(inputs[1]->addr);
|
||||
float *moment = reinterpret_cast<float *>(inputs[2]->addr);
|
||||
|
@ -109,7 +109,7 @@ bool RMSPropCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
|||
MS_LOG(INFO) << "RMSPropCPUKernel lens:" << lens << " size_:" << size_;
|
||||
LaunchRMSPropUnuseCenter(variable, mean_square, moment, gradients, learning_rate);
|
||||
} else {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kRMSPropInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kCenteredRMSPropInputsNum, kernel_name_);
|
||||
T *variable = reinterpret_cast<float *>(inputs[0]->addr);
|
||||
T *mean_gradients = reinterpret_cast<float *>(inputs[1]->addr);
|
||||
T *mean_square = reinterpret_cast<float *>(inputs[2]->addr);
|
||||
|
|
|
@ -28,8 +28,8 @@ constexpr size_t kUnsortedSegmentOutputsNum = 1;
|
|||
void UnsortedSegmentSumCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
|
||||
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
|
||||
segment_ids_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 1);
|
||||
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||
segment_ids_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 1);
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
auto segment_ids_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
|
|
Loading…
Reference in New Issue