forked from mindspore-Ecosystem/mindspore
add complete sequenceslice cpu registers
This commit is contained in:
parent
212213ea60
commit
a17e57a121
|
@ -124,6 +124,55 @@ bool SequenceSliceCpuKernelMod::Launch(const std::vector<KernelTensorPtr> &input
|
|||
|
||||
std::vector<std::pair<KernelAttr, SequenceSliceCpuKernelMod::SequenceSliceFunc>> SequenceSliceCpuKernelMod::func_list_ =
|
||||
{{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeFloat32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeFloat32),
|
||||
&SequenceSliceCpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeFloat32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeFloat32),
|
||||
&SequenceSliceCpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeFloat32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeFloat32),
|
||||
&SequenceSliceCpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeFloat32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeFloat32),
|
||||
&SequenceSliceCpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeFloat32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeFloat32),
|
||||
&SequenceSliceCpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeFloat32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeFloat32),
|
||||
&SequenceSliceCpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeFloat32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeFloat32),
|
||||
&SequenceSliceCpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeFloat32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
|
@ -131,12 +180,110 @@ std::vector<std::pair<KernelAttr, SequenceSliceCpuKernelMod::SequenceSliceFunc>>
|
|||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeFloat32),
|
||||
&SequenceSliceCpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeFloat64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeFloat64),
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeDouble)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeDouble),
|
||||
&SequenceSliceCpuKernelMod::LaunchKernel<double>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeDouble)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeDouble),
|
||||
&SequenceSliceCpuKernelMod::LaunchKernel<double>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeDouble)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeDouble),
|
||||
&SequenceSliceCpuKernelMod::LaunchKernel<double>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeDouble)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeDouble),
|
||||
&SequenceSliceCpuKernelMod::LaunchKernel<double>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeDouble)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeDouble),
|
||||
&SequenceSliceCpuKernelMod::LaunchKernel<double>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeDouble)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeDouble),
|
||||
&SequenceSliceCpuKernelMod::LaunchKernel<double>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeDouble)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeDouble),
|
||||
&SequenceSliceCpuKernelMod::LaunchKernel<double>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeDouble)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeDouble),
|
||||
&SequenceSliceCpuKernelMod::LaunchKernel<double>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt32),
|
||||
&SequenceSliceCpuKernelMod::LaunchKernel<int>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt32),
|
||||
&SequenceSliceCpuKernelMod::LaunchKernel<int>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt32),
|
||||
&SequenceSliceCpuKernelMod::LaunchKernel<int>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt32),
|
||||
&SequenceSliceCpuKernelMod::LaunchKernel<int>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt32),
|
||||
&SequenceSliceCpuKernelMod::LaunchKernel<int>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt32),
|
||||
&SequenceSliceCpuKernelMod::LaunchKernel<int>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt32),
|
||||
&SequenceSliceCpuKernelMod::LaunchKernel<int>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
|
@ -144,6 +291,55 @@ std::vector<std::pair<KernelAttr, SequenceSliceCpuKernelMod::SequenceSliceFunc>>
|
|||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt32),
|
||||
&SequenceSliceCpuKernelMod::LaunchKernel<int>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64),
|
||||
&SequenceSliceCpuKernelMod::LaunchKernel<int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64),
|
||||
&SequenceSliceCpuKernelMod::LaunchKernel<int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64),
|
||||
&SequenceSliceCpuKernelMod::LaunchKernel<int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64),
|
||||
&SequenceSliceCpuKernelMod::LaunchKernel<int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64),
|
||||
&SequenceSliceCpuKernelMod::LaunchKernel<int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64),
|
||||
&SequenceSliceCpuKernelMod::LaunchKernel<int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt32)
|
||||
.AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64),
|
||||
&SequenceSliceCpuKernelMod::LaunchKernel<int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
|
||||
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
|
||||
|
|
|
@ -29,6 +29,18 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
int64_t SequenceSliceGetValue(const std::string &prim_name, const std::string &attr_name, const AbstractBasePtr &abs) {
|
||||
auto build_type = abs->BuildType();
|
||||
auto build_value = abs->BuildValue();
|
||||
if (build_type == kInt32) {
|
||||
return GetValue<int32_t>(build_value);
|
||||
} else if (build_type == kInt64) {
|
||||
return GetValue<int64_t>(build_value);
|
||||
} else {
|
||||
MS_EXCEPTION(TypeError) << "For '" << prim_name << "', the type of '" << attr_name
|
||||
<< "' should be int32, int64 but got: " << abs->BuildType()->ToString();
|
||||
}
|
||||
}
|
||||
AbstractBasePtr SliceInferInner(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
|
@ -60,14 +72,18 @@ AbstractBasePtr SliceInferInner(const PrimitivePtr &primitive, const std::vector
|
|||
// all value is known
|
||||
if (start_abs->BuildValue() != kAnyValue && end_abs->BuildValue() != kAnyValue &&
|
||||
step_abs->BuildValue() != kAnyValue) {
|
||||
auto start_v = GetValue<int64_t>(start_abs->BuildValue());
|
||||
auto end_v = GetValue<int64_t>(end_abs->BuildValue());
|
||||
auto step_v = GetValue<int64_t>(step_abs->BuildValue());
|
||||
int64_t start_v, end_v, step_v;
|
||||
const std::string start_str = "start";
|
||||
const std::string end_str = "end";
|
||||
const std::string step_str = "step";
|
||||
start_v = SequenceSliceGetValue(prim_name, start_str, start_abs);
|
||||
end_v = SequenceSliceGetValue(prim_name, end_str, end_abs);
|
||||
step_v = SequenceSliceGetValue(prim_name, step_str, step_abs);
|
||||
int64_t len = seq_abs->elements().size();
|
||||
auto output_size = SequenceSliceGetOutputSize(start_v, end_v, step_v, len);
|
||||
abstract::AbstractBasePtrList abs{};
|
||||
for (int64_t i = 0; i < output_size; i++) {
|
||||
abs.push_back(std::make_shared<abstract::AbstractScalar>(kAnyValue, kInt64));
|
||||
abs.push_back(std::make_shared<abstract::AbstractScalar>(kAnyValue, seq_abs->ElementsType()[0]));
|
||||
}
|
||||
auto ret = std::make_shared<abstract::AbstractTuple>(abs);
|
||||
return ret;
|
||||
|
|
Loading…
Reference in New Issue