diff --git a/mindspore/lite/src/runtime/kernel/arm/base/strided_slice.cc b/mindspore/lite/src/runtime/kernel/arm/base/strided_slice.cc index e1903f7dfb3..9155a85af1d 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/strided_slice.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/strided_slice.cc @@ -136,6 +136,7 @@ kernel::LiteKernel *CpuStridedSliceKernelCreator(const std::vector(this->op_parameter_); - num_unit_ = static_cast(in_tensors_[kInputIndex]->shape().at(param->perm_[kNHWC_H])); - thread_h_num_ = MSMIN(thread_num_, num_unit_); - thread_h_stride_ = UP_DIV(num_unit_, thread_h_num_); if (!InferShapeDone()) { return RET_OK; } @@ -42,9 +38,12 @@ int TransposeFp16CPUKernel::Init() { } int TransposeFp16CPUKernel::ReSize() { + TransposeParameter *param = reinterpret_cast(this->op_parameter_); + num_unit_ = static_cast(in_tensors_[kInputIndex]->shape().at(param->perm_[kNHWC_H])); + thread_h_num_ = MSMIN(thread_num_, num_unit_); + thread_h_stride_ = UP_DIV(num_unit_, thread_h_num_); auto &in_tensor = in_tensors_.front(); auto &out_tensor = out_tensors_.front(); - auto param = reinterpret_cast(op_parameter_); auto in_shape = in_tensor->shape(); auto out_shape = out_tensor->shape(); param->strides_[param->num_axes_ - 1] = 1; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc index 3eeace0f202..73e432579c6 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc @@ -325,8 +325,10 @@ kernel::LiteKernel *CpuArithmeticFp32KernelCreator(const std::vector return kernel; } +REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Slice, CpuSliceFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Slice, CpuSliceFp32KernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc index 369558e3536..8bfdcb5c69a 100644 --- a/mindspore/lite/src/scheduler.cc +++ b/mindspore/lite/src/scheduler.cc @@ -296,7 +296,8 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector &in_tens TypeId Scheduler::GetFirstFp32Fp16OrInt8Type(const std::vector &in_tensors) { for (const auto &tensor : in_tensors) { auto dtype = tensor->data_type(); - if (dtype == kNumberTypeFloat32 || dtype == kNumberTypeFloat16 || dtype == kNumberTypeInt8) { + if (dtype == kNumberTypeFloat32 || dtype == kNumberTypeFloat16 || dtype == kNumberTypeInt8 || + dtype == kNumberTypeInt32) { return dtype; } } @@ -341,7 +342,8 @@ kernel::SubGraphType Scheduler::GetKernelSubGraphType(kernel::LiteKernel *kernel } else if (desc.arch == kernel::KERNEL_ARCH::kCPU) { if (desc.data_type == kNumberTypeFloat16) { return kernel::kCpuFP16SubGraph; - } else if (desc.data_type == kNumberTypeFloat32 || desc.data_type == kNumberTypeInt8) { + } else if (desc.data_type == kNumberTypeFloat32 || desc.data_type == kNumberTypeInt8 || + desc.data_type == kNumberTypeInt32) { return kernel::kCpuFP32SubGraph; } }