forked from mindspore-Ecosystem/mindspore
fix cpu bug
This commit is contained in:
parent
1e15b00a31
commit
bcf3f6341c
|
@ -98,7 +98,7 @@ template <typename T>
|
||||||
bool RMSPropCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
|
bool RMSPropCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
|
||||||
const std::vector<kernel::AddressPtr> &) {
|
const std::vector<kernel::AddressPtr> &) {
|
||||||
if (!use_center_) {
|
if (!use_center_) {
|
||||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kCenteredRMSPropInputsNum, kernel_name_);
|
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kRMSPropInputsNum, kernel_name_);
|
||||||
float *variable = reinterpret_cast<float *>(inputs[0]->addr);
|
float *variable = reinterpret_cast<float *>(inputs[0]->addr);
|
||||||
float *mean_square = reinterpret_cast<float *>(inputs[1]->addr);
|
float *mean_square = reinterpret_cast<float *>(inputs[1]->addr);
|
||||||
float *moment = reinterpret_cast<float *>(inputs[2]->addr);
|
float *moment = reinterpret_cast<float *>(inputs[2]->addr);
|
||||||
|
@ -109,7 +109,7 @@ bool RMSPropCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||||
MS_LOG(INFO) << "RMSPropCPUKernel lens:" << lens << " size_:" << size_;
|
MS_LOG(INFO) << "RMSPropCPUKernel lens:" << lens << " size_:" << size_;
|
||||||
LaunchRMSPropUnuseCenter(variable, mean_square, moment, gradients, learning_rate);
|
LaunchRMSPropUnuseCenter(variable, mean_square, moment, gradients, learning_rate);
|
||||||
} else {
|
} else {
|
||||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kRMSPropInputsNum, kernel_name_);
|
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kCenteredRMSPropInputsNum, kernel_name_);
|
||||||
T *variable = reinterpret_cast<float *>(inputs[0]->addr);
|
T *variable = reinterpret_cast<float *>(inputs[0]->addr);
|
||||||
T *mean_gradients = reinterpret_cast<float *>(inputs[1]->addr);
|
T *mean_gradients = reinterpret_cast<float *>(inputs[1]->addr);
|
||||||
T *mean_square = reinterpret_cast<float *>(inputs[2]->addr);
|
T *mean_square = reinterpret_cast<float *>(inputs[2]->addr);
|
||||||
|
|
|
@ -28,8 +28,8 @@ constexpr size_t kUnsortedSegmentOutputsNum = 1;
|
||||||
void UnsortedSegmentSumCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
void UnsortedSegmentSumCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
|
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
|
||||||
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
|
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||||
segment_ids_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 1);
|
segment_ids_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 1);
|
||||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||||
auto segment_ids_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
auto segment_ids_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||||
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||||
|
|
Loading…
Reference in New Issue