diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/slice_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/slice_gpu_kernel.cc index 34507cdd296..211547c9192 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/slice_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/slice_gpu_kernel.cc @@ -26,150 +26,163 @@ std::unique_ptr CreateSliceKernelPtr(const std::s using SlicePtrCreatorFunc = std::function(const std::string &, const uint32_t &)>; -const std::vector> kernel_attr = {{KernelAttr() - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat64), - CreateSliceKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat32), - CreateSliceKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat16), - CreateSliceKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64), - CreateSliceKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt32), - CreateSliceKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt16) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt16), - CreateSliceKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt8) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt8), - CreateSliceKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeUInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeUInt64), - CreateSliceKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeUInt32) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeUInt32), - CreateSliceKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeUInt16) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeUInt16), - CreateSliceKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeUInt8) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeUInt8), - CreateSliceKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeBool) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeBool), - CreateSliceKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat64), - CreateSliceKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat32), - CreateSliceKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat16), - CreateSliceKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt64), - CreateSliceKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt32), - CreateSliceKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt16) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt16), - CreateSliceKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt8) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt8), - CreateSliceKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeUInt64) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeUInt64), - CreateSliceKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeUInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeUInt32), - CreateSliceKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeUInt16) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeUInt16), - CreateSliceKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeUInt8) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeUInt8), - CreateSliceKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeBool) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeBool), - CreateSliceKernelPtr}}; +const std::vector> kernel_attr = { + {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateSliceKernelPtr}, + {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateSliceKernelPtr}, + {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), CreateSliceKernelPtr}, + {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), CreateSliceKernelPtr}, + {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CreateSliceKernelPtr}, + {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), CreateSliceKernelPtr}, + {KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), CreateSliceKernelPtr}, + {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), CreateSliceKernelPtr}, + {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), CreateSliceKernelPtr}, + {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), CreateSliceKernelPtr}, + {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), CreateSliceKernelPtr}, + {KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), CreateSliceKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat64), + CreateSliceKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat32), + CreateSliceKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat16), + CreateSliceKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64), + CreateSliceKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt32), + CreateSliceKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt16), + CreateSliceKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt8), + CreateSliceKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeUInt64), + CreateSliceKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeUInt32), + CreateSliceKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeUInt16), + CreateSliceKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeUInt8), + CreateSliceKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeBool) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeBool), + CreateSliceKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat64), + CreateSliceKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32), + CreateSliceKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat16), + CreateSliceKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt64), + CreateSliceKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + CreateSliceKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt16), + CreateSliceKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt8), + CreateSliceKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeUInt64), + CreateSliceKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeUInt32), + CreateSliceKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeUInt16), + CreateSliceKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeUInt8), + CreateSliceKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeBool) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeBool), + CreateSliceKernelPtr}}; } // namespace bool SliceGpuKernelMod::Launch(const std::vector &inputs, const std::vector &workspace, diff --git a/mindspore/ccsrc/plugin/device/gpu/optimizer/reg_gpu_const_input_to_attr.h b/mindspore/ccsrc/plugin/device/gpu/optimizer/reg_gpu_const_input_to_attr.h index e507e91447c..9b354fa60e9 100644 --- a/mindspore/ccsrc/plugin/device/gpu/optimizer/reg_gpu_const_input_to_attr.h +++ b/mindspore/ccsrc/plugin/device/gpu/optimizer/reg_gpu_const_input_to_attr.h @@ -67,6 +67,7 @@ RER_GPU_STATIC_CONST_TO_ATTR(kSubscalarOpName, 1); RER_GPU_STATIC_CONST_TO_ATTR(kTensorCopySlicesOpName, 2, 3, 4); RER_GPU_STATIC_CONST_TO_ATTR(kTileOpName, 1); RER_GPU_STATIC_CONST_TO_ATTR(kTransposeOpName, 1); +RER_GPU_STATIC_CONST_TO_ATTR(kSliceOpName, 1, 2); } // namespace mindspore::opt #endif // MINDSPORE_CCSRC_PLUGIN_GPU_OPTIMIZER_REG_GPU_CONST_INPUT_TO_ATTR_H_