From a17e57a121e31cc62a7beabe3b0e5d2f781e2d88 Mon Sep 17 00:00:00 2001 From: yiyanzhi_akane Date: Fri, 17 Feb 2023 15:00:41 +0800 Subject: [PATCH] add complete sequenceslice cpu registers --- .../sequence/sequence_slice_cpu_kernel.cc | 206 +++++++++++++++++- mindspore/core/ops/sequence_slice.cc | 24 +- 2 files changed, 221 insertions(+), 9 deletions(-) diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/sequence/sequence_slice_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/sequence/sequence_slice_cpu_kernel.cc index ddc4821e20d..966e96c08c7 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/sequence/sequence_slice_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/sequence/sequence_slice_cpu_kernel.cc @@ -124,6 +124,55 @@ bool SequenceSliceCpuKernelMod::Launch(const std::vector &input std::vector> SequenceSliceCpuKernelMod::func_list_ = {{KernelAttr() + .AddInputAttr(kObjectTypeTuple, kNumberTypeFloat32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddOutputAttr(kObjectTypeTuple, kNumberTypeFloat32), + &SequenceSliceCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kObjectTypeTuple, kNumberTypeFloat32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kObjectTypeTuple, kNumberTypeFloat32), + &SequenceSliceCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kObjectTypeTuple, kNumberTypeFloat32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddOutputAttr(kObjectTypeTuple, kNumberTypeFloat32), + &SequenceSliceCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kObjectTypeTuple, kNumberTypeFloat32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kObjectTypeTuple, kNumberTypeFloat32), + &SequenceSliceCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kObjectTypeTuple, kNumberTypeFloat32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddOutputAttr(kObjectTypeTuple, kNumberTypeFloat32), + &SequenceSliceCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kObjectTypeTuple, kNumberTypeFloat32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kObjectTypeTuple, kNumberTypeFloat32), + &SequenceSliceCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kObjectTypeTuple, kNumberTypeFloat32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddOutputAttr(kObjectTypeTuple, kNumberTypeFloat32), + &SequenceSliceCpuKernelMod::LaunchKernel}, + {KernelAttr() .AddInputAttr(kObjectTypeTuple, kNumberTypeFloat32) .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) @@ -131,12 +180,110 @@ std::vector> .AddOutputAttr(kObjectTypeTuple, kNumberTypeFloat32), &SequenceSliceCpuKernelMod::LaunchKernel}, {KernelAttr() - .AddInputAttr(kObjectTypeTuple, kNumberTypeFloat64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddOutputAttr(kObjectTypeTuple, kNumberTypeFloat64), + .AddInputAttr(kObjectTypeTuple, kNumberTypeDouble) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddOutputAttr(kObjectTypeTuple, kNumberTypeDouble), &SequenceSliceCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kObjectTypeTuple, kNumberTypeDouble) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kObjectTypeTuple, kNumberTypeDouble), + &SequenceSliceCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kObjectTypeTuple, kNumberTypeDouble) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddOutputAttr(kObjectTypeTuple, kNumberTypeDouble), + &SequenceSliceCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kObjectTypeTuple, kNumberTypeDouble) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kObjectTypeTuple, kNumberTypeDouble), + &SequenceSliceCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kObjectTypeTuple, kNumberTypeDouble) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddOutputAttr(kObjectTypeTuple, kNumberTypeDouble), + &SequenceSliceCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kObjectTypeTuple, kNumberTypeDouble) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kObjectTypeTuple, kNumberTypeDouble), + &SequenceSliceCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kObjectTypeTuple, kNumberTypeDouble) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddOutputAttr(kObjectTypeTuple, kNumberTypeDouble), + &SequenceSliceCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kObjectTypeTuple, kNumberTypeDouble) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kObjectTypeTuple, kNumberTypeDouble), + &SequenceSliceCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kObjectTypeTuple, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddOutputAttr(kObjectTypeTuple, kNumberTypeInt32), + &SequenceSliceCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kObjectTypeTuple, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kObjectTypeTuple, kNumberTypeInt32), + &SequenceSliceCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kObjectTypeTuple, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddOutputAttr(kObjectTypeTuple, kNumberTypeInt32), + &SequenceSliceCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kObjectTypeTuple, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kObjectTypeTuple, kNumberTypeInt32), + &SequenceSliceCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kObjectTypeTuple, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddOutputAttr(kObjectTypeTuple, kNumberTypeInt32), + &SequenceSliceCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kObjectTypeTuple, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kObjectTypeTuple, kNumberTypeInt32), + &SequenceSliceCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kObjectTypeTuple, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddOutputAttr(kObjectTypeTuple, kNumberTypeInt32), + &SequenceSliceCpuKernelMod::LaunchKernel}, {KernelAttr() .AddInputAttr(kObjectTypeTuple, kNumberTypeInt32) .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) @@ -144,6 +291,55 @@ std::vector> .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) .AddOutputAttr(kObjectTypeTuple, kNumberTypeInt32), &SequenceSliceCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64), + &SequenceSliceCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64), + &SequenceSliceCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64), + &SequenceSliceCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64), + &SequenceSliceCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64), + &SequenceSliceCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64), + &SequenceSliceCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32) + .AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64), + &SequenceSliceCpuKernelMod::LaunchKernel}, {KernelAttr() .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64) .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) diff --git a/mindspore/core/ops/sequence_slice.cc b/mindspore/core/ops/sequence_slice.cc index a948d1068ef..1548a50687b 100644 --- a/mindspore/core/ops/sequence_slice.cc +++ b/mindspore/core/ops/sequence_slice.cc @@ -29,6 +29,18 @@ namespace mindspore { namespace ops { namespace { +int64_t SequenceSliceGetValue(const std::string &prim_name, const std::string &attr_name, const AbstractBasePtr &abs) { + auto build_type = abs->BuildType(); + auto build_value = abs->BuildValue(); + if (build_type == kInt32) { + return GetValue(build_value); + } else if (build_type == kInt64) { + return GetValue(build_value); + } else { + MS_EXCEPTION(TypeError) << "For '" << prim_name << "', the type of '" << attr_name + << "' should be int32, int64 but got: " << abs->BuildType()->ToString(); + } +} AbstractBasePtr SliceInferInner(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); @@ -60,14 +72,18 @@ AbstractBasePtr SliceInferInner(const PrimitivePtr &primitive, const std::vector // all value is known if (start_abs->BuildValue() != kAnyValue && end_abs->BuildValue() != kAnyValue && step_abs->BuildValue() != kAnyValue) { - auto start_v = GetValue(start_abs->BuildValue()); - auto end_v = GetValue(end_abs->BuildValue()); - auto step_v = GetValue(step_abs->BuildValue()); + int64_t start_v, end_v, step_v; + const std::string start_str = "start"; + const std::string end_str = "end"; + const std::string step_str = "step"; + start_v = SequenceSliceGetValue(prim_name, start_str, start_abs); + end_v = SequenceSliceGetValue(prim_name, end_str, end_abs); + step_v = SequenceSliceGetValue(prim_name, step_str, step_abs); int64_t len = seq_abs->elements().size(); auto output_size = SequenceSliceGetOutputSize(start_v, end_v, step_v, len); abstract::AbstractBasePtrList abs{}; for (int64_t i = 0; i < output_size; i++) { - abs.push_back(std::make_shared(kAnyValue, kInt64)); + abs.push_back(std::make_shared(kAnyValue, seq_abs->ElementsType()[0])); } auto ret = std::make_shared(abs); return ret;