revert: slice gpu attrs to input feature
This commit is contained in:
parent
a47dacb5ca
commit
6979bdd9b7
|
@ -26,150 +26,163 @@ std::unique_ptr<cukernel::GpuKernelHelperBase> CreateSliceKernelPtr(const std::s
|
|||
using SlicePtrCreatorFunc =
|
||||
std::function<std::unique_ptr<cukernel::GpuKernelHelperBase>(const std::string &, const uint32_t &)>;
|
||||
|
||||
const std::vector<std::pair<KernelAttr, SlicePtrCreatorFunc>> kernel_attr = {{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
CreateSliceKernelPtr<double, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
CreateSliceKernelPtr<float, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
CreateSliceKernelPtr<half, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
CreateSliceKernelPtr<int64_t, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
CreateSliceKernelPtr<int32_t, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt16),
|
||||
CreateSliceKernelPtr<int16_t, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt8),
|
||||
CreateSliceKernelPtr<char, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeUInt64),
|
||||
CreateSliceKernelPtr<uint64_t, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeUInt32),
|
||||
CreateSliceKernelPtr<uint32_t, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt16)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeUInt16),
|
||||
CreateSliceKernelPtr<uint16_t, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeUInt8),
|
||||
CreateSliceKernelPtr<uchar, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeBool),
|
||||
CreateSliceKernelPtr<bool, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
CreateSliceKernelPtr<double, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
CreateSliceKernelPtr<float, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
CreateSliceKernelPtr<half, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
CreateSliceKernelPtr<int64_t, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
CreateSliceKernelPtr<int32_t, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt16),
|
||||
CreateSliceKernelPtr<int16_t, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt8),
|
||||
CreateSliceKernelPtr<char, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeUInt64),
|
||||
CreateSliceKernelPtr<uint64_t, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeUInt32),
|
||||
CreateSliceKernelPtr<uint32_t, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt16)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeUInt16),
|
||||
CreateSliceKernelPtr<uint16_t, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeUInt8),
|
||||
CreateSliceKernelPtr<uchar, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeBool),
|
||||
CreateSliceKernelPtr<bool, int32_t>}};
|
||||
const std::vector<std::pair<KernelAttr, SlicePtrCreatorFunc>> kernel_attr = {
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateSliceKernelPtr<double>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateSliceKernelPtr<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), CreateSliceKernelPtr<half>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), CreateSliceKernelPtr<int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CreateSliceKernelPtr<int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), CreateSliceKernelPtr<int16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), CreateSliceKernelPtr<char>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), CreateSliceKernelPtr<uint64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), CreateSliceKernelPtr<uint32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), CreateSliceKernelPtr<uint16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), CreateSliceKernelPtr<uchar>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), CreateSliceKernelPtr<bool>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
CreateSliceKernelPtr<double, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
CreateSliceKernelPtr<float, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
CreateSliceKernelPtr<half, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
CreateSliceKernelPtr<int64_t, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
CreateSliceKernelPtr<int32_t, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt16),
|
||||
CreateSliceKernelPtr<int16_t, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt8),
|
||||
CreateSliceKernelPtr<char, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeUInt64),
|
||||
CreateSliceKernelPtr<uint64_t, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeUInt32),
|
||||
CreateSliceKernelPtr<uint32_t, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt16)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeUInt16),
|
||||
CreateSliceKernelPtr<uint16_t, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeUInt8),
|
||||
CreateSliceKernelPtr<uchar, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeBool),
|
||||
CreateSliceKernelPtr<bool, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
CreateSliceKernelPtr<double, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
CreateSliceKernelPtr<float, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
CreateSliceKernelPtr<half, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
CreateSliceKernelPtr<int64_t, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
CreateSliceKernelPtr<int32_t, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt16),
|
||||
CreateSliceKernelPtr<int16_t, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt8),
|
||||
CreateSliceKernelPtr<char, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeUInt64),
|
||||
CreateSliceKernelPtr<uint64_t, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeUInt32),
|
||||
CreateSliceKernelPtr<uint32_t, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt16)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeUInt16),
|
||||
CreateSliceKernelPtr<uint16_t, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeUInt8),
|
||||
CreateSliceKernelPtr<uchar, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeBool),
|
||||
CreateSliceKernelPtr<bool, int32_t>}};
|
||||
} // namespace
|
||||
|
||||
bool SliceGpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
|
|
|
@ -67,6 +67,7 @@ RER_GPU_STATIC_CONST_TO_ATTR(kSubscalarOpName, 1);
|
|||
RER_GPU_STATIC_CONST_TO_ATTR(kTensorCopySlicesOpName, 2, 3, 4);
|
||||
RER_GPU_STATIC_CONST_TO_ATTR(kTileOpName, 1);
|
||||
RER_GPU_STATIC_CONST_TO_ATTR(kTransposeOpName, 1);
|
||||
RER_GPU_STATIC_CONST_TO_ATTR(kSliceOpName, 1, 2);
|
||||
} // namespace mindspore::opt
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_GPU_OPTIMIZER_REG_GPU_CONST_INPUT_TO_ATTR_H_
|
||||
|
|
Loading…
Reference in New Issue