!42582 fix reverse sequence ops type bugs r1.9

Merge pull request !42582 from wtcheng/r1.9
This commit is contained in:
i-robot 2022-09-22 02:45:58 +00:00 committed by Gitee
commit 2395b71ed7
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 29 additions and 4 deletions

View File

@ -185,6 +185,17 @@ std::vector<
&ReverseSequenceCpuKernelMod::LaunchKernel<float, int32_t>, &ReverseSequenceCpuKernelMod::ResizeKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
&ReverseSequenceCpuKernelMod::LaunchKernel<double, int32_t>, &ReverseSequenceCpuKernelMod::ResizeKernel<double>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex64),
&ReverseSequenceCpuKernelMod::LaunchKernel<complex64, int32_t>,
&ReverseSequenceCpuKernelMod::ResizeKernel<complex64>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeComplex128),
&ReverseSequenceCpuKernelMod::LaunchKernel<complex128, int32_t>,
&ReverseSequenceCpuKernelMod::ResizeKernel<complex128>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
&ReverseSequenceCpuKernelMod::LaunchKernel<bool, int32_t>, &ReverseSequenceCpuKernelMod::ResizeKernel<bool>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
&ReverseSequenceCpuKernelMod::LaunchKernel<int8_t, int64_t>, &ReverseSequenceCpuKernelMod::ResizeKernel<int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
@ -207,7 +218,18 @@ std::vector<
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
&ReverseSequenceCpuKernelMod::LaunchKernel<float, int64_t>, &ReverseSequenceCpuKernelMod::ResizeKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
&ReverseSequenceCpuKernelMod::LaunchKernel<double, int64_t>, &ReverseSequenceCpuKernelMod::ResizeKernel<double>}};
&ReverseSequenceCpuKernelMod::LaunchKernel<double, int64_t>, &ReverseSequenceCpuKernelMod::ResizeKernel<double>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex64),
&ReverseSequenceCpuKernelMod::LaunchKernel<complex64, int64_t>,
&ReverseSequenceCpuKernelMod::ResizeKernel<complex64>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeComplex128),
&ReverseSequenceCpuKernelMod::LaunchKernel<complex128, int64_t>,
&ReverseSequenceCpuKernelMod::ResizeKernel<complex128>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
&ReverseSequenceCpuKernelMod::LaunchKernel<bool, int64_t>, &ReverseSequenceCpuKernelMod::ResizeKernel<bool>}};
std::vector<KernelAttr> ReverseSequenceCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_REVERSE_SEQUENCE_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_REVERSE_SEQUENCE_CPU_KERNEL_H_
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_REVERSE_SEQUENCE_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_REVERSE_SEQUENCE_CPU_KERNEL_H_
#include <vector>
#include <map>
@ -24,11 +24,14 @@
#include <utility>
#include <string>
#include <tuple>
#include <complex>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
using complex64 = std::complex<float>;
using complex128 = std::complex<double>;
constexpr auto kUnknown = "Unknown";
class ReverseSequenceCpuKernelMod : public NativeCpuKernelMod {
@ -84,4 +87,4 @@ class ReverseSequenceCpuKernelMod : public NativeCpuKernelMod {
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_REVERSE_SEQUENCE_CPU_KERNEL_H_
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_REVERSE_SEQUENCE_CPU_KERNEL_H_