From 166bdbff1c836b219a309bfc3293760a23a31a0e Mon Sep 17 00:00:00 2001 From: wutiancheng Date: Wed, 21 Sep 2022 19:27:57 +0800 Subject: [PATCH] fix reverse sequence ops type bugs r1.9 --- .../cpu/kernel/reverse_sequence_cpu_kernel.cc | 24 ++++++++++++++++++- .../cpu/kernel/reverse_sequence_cpu_kernel.h | 9 ++++--- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/reverse_sequence_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/reverse_sequence_cpu_kernel.cc index 7a4494b7f93..f0ef03dbd88 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/reverse_sequence_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/reverse_sequence_cpu_kernel.cc @@ -185,6 +185,17 @@ std::vector< &ReverseSequenceCpuKernelMod::LaunchKernel, &ReverseSequenceCpuKernelMod::ResizeKernel}, {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), &ReverseSequenceCpuKernelMod::LaunchKernel, &ReverseSequenceCpuKernelMod::ResizeKernel}, + {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex64), + &ReverseSequenceCpuKernelMod::LaunchKernel, + &ReverseSequenceCpuKernelMod::ResizeKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeComplex128) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeComplex128), + &ReverseSequenceCpuKernelMod::LaunchKernel, + &ReverseSequenceCpuKernelMod::ResizeKernel}, + {KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), + &ReverseSequenceCpuKernelMod::LaunchKernel, &ReverseSequenceCpuKernelMod::ResizeKernel}, {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8), &ReverseSequenceCpuKernelMod::LaunchKernel, &ReverseSequenceCpuKernelMod::ResizeKernel}, {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16), @@ -207,7 +218,18 @@ std::vector< {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), &ReverseSequenceCpuKernelMod::LaunchKernel, &ReverseSequenceCpuKernelMod::ResizeKernel}, {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64), - &ReverseSequenceCpuKernelMod::LaunchKernel, &ReverseSequenceCpuKernelMod::ResizeKernel}}; + &ReverseSequenceCpuKernelMod::LaunchKernel, &ReverseSequenceCpuKernelMod::ResizeKernel}, + {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex64), + &ReverseSequenceCpuKernelMod::LaunchKernel, + &ReverseSequenceCpuKernelMod::ResizeKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeComplex128) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeComplex128), + &ReverseSequenceCpuKernelMod::LaunchKernel, + &ReverseSequenceCpuKernelMod::ResizeKernel}, + {KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), + &ReverseSequenceCpuKernelMod::LaunchKernel, &ReverseSequenceCpuKernelMod::ResizeKernel}}; std::vector ReverseSequenceCpuKernelMod::GetOpSupport() { std::vector support_list; diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/reverse_sequence_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/reverse_sequence_cpu_kernel.h index 8be91125b22..b72628430fd 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/reverse_sequence_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/reverse_sequence_cpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_REVERSE_SEQUENCE_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_REVERSE_SEQUENCE_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_REVERSE_SEQUENCE_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_REVERSE_SEQUENCE_CPU_KERNEL_H_ #include #include @@ -24,11 +24,14 @@ #include #include #include +#include #include "plugin/device/cpu/kernel/cpu_kernel.h" #include "plugin/factory/ms_factory.h" namespace mindspore { namespace kernel { +using complex64 = std::complex; +using complex128 = std::complex; constexpr auto kUnknown = "Unknown"; class ReverseSequenceCpuKernelMod : public NativeCpuKernelMod { @@ -84,4 +87,4 @@ class ReverseSequenceCpuKernelMod : public NativeCpuKernelMod { }; } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_REVERSE_SEQUENCE_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_REVERSE_SEQUENCE_CPU_KERNEL_H_