!42582 fix reverse sequence ops type bugs r1.9

Merge pull request !42582 from wtcheng/r1.9
This commit is contained in:
i-robot 2022-09-22 02:45:58 +00:00 committed by Gitee
commit 2395b71ed7
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 29 additions and 4 deletions

View File

@ -185,6 +185,17 @@ std::vector<
&ReverseSequenceCpuKernelMod::LaunchKernel<float, int32_t>, &ReverseSequenceCpuKernelMod::ResizeKernel<float>}, &ReverseSequenceCpuKernelMod::LaunchKernel<float, int32_t>, &ReverseSequenceCpuKernelMod::ResizeKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
&ReverseSequenceCpuKernelMod::LaunchKernel<double, int32_t>, &ReverseSequenceCpuKernelMod::ResizeKernel<double>}, &ReverseSequenceCpuKernelMod::LaunchKernel<double, int32_t>, &ReverseSequenceCpuKernelMod::ResizeKernel<double>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex64),
&ReverseSequenceCpuKernelMod::LaunchKernel<complex64, int32_t>,
&ReverseSequenceCpuKernelMod::ResizeKernel<complex64>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeComplex128),
&ReverseSequenceCpuKernelMod::LaunchKernel<complex128, int32_t>,
&ReverseSequenceCpuKernelMod::ResizeKernel<complex128>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
&ReverseSequenceCpuKernelMod::LaunchKernel<bool, int32_t>, &ReverseSequenceCpuKernelMod::ResizeKernel<bool>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8), {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
&ReverseSequenceCpuKernelMod::LaunchKernel<int8_t, int64_t>, &ReverseSequenceCpuKernelMod::ResizeKernel<int8_t>}, &ReverseSequenceCpuKernelMod::LaunchKernel<int8_t, int64_t>, &ReverseSequenceCpuKernelMod::ResizeKernel<int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16), {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
@ -207,7 +218,18 @@ std::vector<
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
&ReverseSequenceCpuKernelMod::LaunchKernel<float, int64_t>, &ReverseSequenceCpuKernelMod::ResizeKernel<float>}, &ReverseSequenceCpuKernelMod::LaunchKernel<float, int64_t>, &ReverseSequenceCpuKernelMod::ResizeKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64), {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
&ReverseSequenceCpuKernelMod::LaunchKernel<double, int64_t>, &ReverseSequenceCpuKernelMod::ResizeKernel<double>}}; &ReverseSequenceCpuKernelMod::LaunchKernel<double, int64_t>, &ReverseSequenceCpuKernelMod::ResizeKernel<double>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex64),
&ReverseSequenceCpuKernelMod::LaunchKernel<complex64, int64_t>,
&ReverseSequenceCpuKernelMod::ResizeKernel<complex64>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeComplex128),
&ReverseSequenceCpuKernelMod::LaunchKernel<complex128, int64_t>,
&ReverseSequenceCpuKernelMod::ResizeKernel<complex128>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
&ReverseSequenceCpuKernelMod::LaunchKernel<bool, int64_t>, &ReverseSequenceCpuKernelMod::ResizeKernel<bool>}};
std::vector<KernelAttr> ReverseSequenceCpuKernelMod::GetOpSupport() { std::vector<KernelAttr> ReverseSequenceCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list; std::vector<KernelAttr> support_list;

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_REVERSE_SEQUENCE_CPU_KERNEL_H_ #ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_REVERSE_SEQUENCE_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_REVERSE_SEQUENCE_CPU_KERNEL_H_ #define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_REVERSE_SEQUENCE_CPU_KERNEL_H_
#include <vector> #include <vector>
#include <map> #include <map>
@ -24,11 +24,14 @@
#include <utility> #include <utility>
#include <string> #include <string>
#include <tuple> #include <tuple>
#include <complex>
#include "plugin/device/cpu/kernel/cpu_kernel.h" #include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h" #include "plugin/factory/ms_factory.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
using complex64 = std::complex<float>;
using complex128 = std::complex<double>;
constexpr auto kUnknown = "Unknown"; constexpr auto kUnknown = "Unknown";
class ReverseSequenceCpuKernelMod : public NativeCpuKernelMod { class ReverseSequenceCpuKernelMod : public NativeCpuKernelMod {
@ -84,4 +87,4 @@ class ReverseSequenceCpuKernelMod : public NativeCpuKernelMod {
}; };
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_REVERSE_SEQUENCE_CPU_KERNEL_H_ #endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_REVERSE_SEQUENCE_CPU_KERNEL_H_