add complete sequenceslice cpu registers

This commit is contained in:
yiyanzhi_akane 2023-02-17 15:00:41 +08:00
parent 212213ea60
commit a17e57a121
2 changed files with 221 additions and 9 deletions

View File

@ -124,6 +124,55 @@ bool SequenceSliceCpuKernelMod::Launch(const std::vector<KernelTensorPtr> &input
std::vector<std::pair<KernelAttr, SequenceSliceCpuKernelMod::SequenceSliceFunc>> SequenceSliceCpuKernelMod::func_list_ = std::vector<std::pair<KernelAttr, SequenceSliceCpuKernelMod::SequenceSliceFunc>> SequenceSliceCpuKernelMod::func_list_ =
{{KernelAttr() {{KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeFloat32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeFloat32),
&SequenceSliceCpuKernelMod::LaunchKernel<float>},
{KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeFloat32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeFloat32),
&SequenceSliceCpuKernelMod::LaunchKernel<float>},
{KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeFloat32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeFloat32),
&SequenceSliceCpuKernelMod::LaunchKernel<float>},
{KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeFloat32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeFloat32),
&SequenceSliceCpuKernelMod::LaunchKernel<float>},
{KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeFloat32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeFloat32),
&SequenceSliceCpuKernelMod::LaunchKernel<float>},
{KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeFloat32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeFloat32),
&SequenceSliceCpuKernelMod::LaunchKernel<float>},
{KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeFloat32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeFloat32),
&SequenceSliceCpuKernelMod::LaunchKernel<float>},
{KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeFloat32) .AddInputAttr(kObjectTypeTuple, kNumberTypeFloat32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
@ -131,12 +180,110 @@ std::vector<std::pair<KernelAttr, SequenceSliceCpuKernelMod::SequenceSliceFunc>>
.AddOutputAttr(kObjectTypeTuple, kNumberTypeFloat32), .AddOutputAttr(kObjectTypeTuple, kNumberTypeFloat32),
&SequenceSliceCpuKernelMod::LaunchKernel<float>}, &SequenceSliceCpuKernelMod::LaunchKernel<float>},
{KernelAttr() {KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeFloat64) .AddInputAttr(kObjectTypeTuple, kNumberTypeDouble)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) .AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeFloat64), .AddOutputAttr(kObjectTypeTuple, kNumberTypeDouble),
&SequenceSliceCpuKernelMod::LaunchKernel<double>}, &SequenceSliceCpuKernelMod::LaunchKernel<double>},
{KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeDouble)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeDouble),
&SequenceSliceCpuKernelMod::LaunchKernel<double>},
{KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeDouble)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeDouble),
&SequenceSliceCpuKernelMod::LaunchKernel<double>},
{KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeDouble)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeDouble),
&SequenceSliceCpuKernelMod::LaunchKernel<double>},
{KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeDouble)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeDouble),
&SequenceSliceCpuKernelMod::LaunchKernel<double>},
{KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeDouble)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeDouble),
&SequenceSliceCpuKernelMod::LaunchKernel<double>},
{KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeDouble)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeDouble),
&SequenceSliceCpuKernelMod::LaunchKernel<double>},
{KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeDouble)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeDouble),
&SequenceSliceCpuKernelMod::LaunchKernel<double>},
{KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt32),
&SequenceSliceCpuKernelMod::LaunchKernel<int>},
{KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt32),
&SequenceSliceCpuKernelMod::LaunchKernel<int>},
{KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt32),
&SequenceSliceCpuKernelMod::LaunchKernel<int>},
{KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt32),
&SequenceSliceCpuKernelMod::LaunchKernel<int>},
{KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt32),
&SequenceSliceCpuKernelMod::LaunchKernel<int>},
{KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt32),
&SequenceSliceCpuKernelMod::LaunchKernel<int>},
{KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt32),
&SequenceSliceCpuKernelMod::LaunchKernel<int>},
{KernelAttr() {KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32) .AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
@ -144,6 +291,55 @@ std::vector<std::pair<KernelAttr, SequenceSliceCpuKernelMod::SequenceSliceFunc>>
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt32), .AddOutputAttr(kObjectTypeTuple, kNumberTypeInt32),
&SequenceSliceCpuKernelMod::LaunchKernel<int>}, &SequenceSliceCpuKernelMod::LaunchKernel<int>},
{KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64),
&SequenceSliceCpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64),
&SequenceSliceCpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64),
&SequenceSliceCpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64),
&SequenceSliceCpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64),
&SequenceSliceCpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64),
&SequenceSliceCpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64),
&SequenceSliceCpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr() {KernelAttr()
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64) .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)

View File

@ -29,6 +29,18 @@
namespace mindspore { namespace mindspore {
namespace ops { namespace ops {
namespace { namespace {
int64_t SequenceSliceGetValue(const std::string &prim_name, const std::string &attr_name, const AbstractBasePtr &abs) {
auto build_type = abs->BuildType();
auto build_value = abs->BuildValue();
if (build_type == kInt32) {
return GetValue<int32_t>(build_value);
} else if (build_type == kInt64) {
return GetValue<int64_t>(build_value);
} else {
MS_EXCEPTION(TypeError) << "For '" << prim_name << "', the type of '" << attr_name
<< "' should be int32, int64 but got: " << abs->BuildType()->ToString();
}
}
AbstractBasePtr SliceInferInner(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { AbstractBasePtr SliceInferInner(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name(); auto prim_name = primitive->name();
@ -60,14 +72,18 @@ AbstractBasePtr SliceInferInner(const PrimitivePtr &primitive, const std::vector
// all value is known // all value is known
if (start_abs->BuildValue() != kAnyValue && end_abs->BuildValue() != kAnyValue && if (start_abs->BuildValue() != kAnyValue && end_abs->BuildValue() != kAnyValue &&
step_abs->BuildValue() != kAnyValue) { step_abs->BuildValue() != kAnyValue) {
auto start_v = GetValue<int64_t>(start_abs->BuildValue()); int64_t start_v, end_v, step_v;
auto end_v = GetValue<int64_t>(end_abs->BuildValue()); const std::string start_str = "start";
auto step_v = GetValue<int64_t>(step_abs->BuildValue()); const std::string end_str = "end";
const std::string step_str = "step";
start_v = SequenceSliceGetValue(prim_name, start_str, start_abs);
end_v = SequenceSliceGetValue(prim_name, end_str, end_abs);
step_v = SequenceSliceGetValue(prim_name, step_str, step_abs);
int64_t len = seq_abs->elements().size(); int64_t len = seq_abs->elements().size();
auto output_size = SequenceSliceGetOutputSize(start_v, end_v, step_v, len); auto output_size = SequenceSliceGetOutputSize(start_v, end_v, step_v, len);
abstract::AbstractBasePtrList abs{}; abstract::AbstractBasePtrList abs{};
for (int64_t i = 0; i < output_size; i++) { for (int64_t i = 0; i < output_size; i++) {
abs.push_back(std::make_shared<abstract::AbstractScalar>(kAnyValue, kInt64)); abs.push_back(std::make_shared<abstract::AbstractScalar>(kAnyValue, seq_abs->ElementsType()[0]));
} }
auto ret = std::make_shared<abstract::AbstractTuple>(abs); auto ret = std::make_shared<abstract::AbstractTuple>(abs);
return ret; return ret;