diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/unsqueeze.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/unsqueeze.cc index af3e1562446..bfe611677a2 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/unsqueeze.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/unsqueeze.cc @@ -28,17 +28,15 @@ using mindspore::schema::PrimitiveType_Unsqueeze; namespace mindspore::kernel { int UnsqueezeCPUKernel::Init() { - if (context_->infer_shape_interrupt_ && !context_->running_) { - set_need_reinit(); + if (!InferShapeDone()) { return RET_OK; } - int ret = ReSize(); - return ret; + return ReSize(); } int UnsqueezeCPUKernel::ReSize() { data_size_ = in_tensors_.at(0)->ElementsNum(); - thread_sz_count_ = MSMIN(thread_count_, data_size_); + thread_sz_count_ = MSMIN(context_->thread_num_, data_size_); thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_); return RET_OK; } @@ -48,7 +46,7 @@ int UnsqueezeCPUKernel::DoUnsqueeze(int task_id) { if (size == 0) { return RET_OK; } - int offset = task_id * thread_sz_stride_; + size_t offset = task_id * thread_sz_stride_ * sizeof(float); int ret = Unsqueeze(in_ptr_ + offset, out_ptr_ + offset, size * sizeof(float)); if (ret != RET_OK) { MS_LOG(ERROR) << "UnsqueezeRun error task_id[" << task_id << "] error_code[" << ret << "]"; @@ -73,8 +71,8 @@ int UnsqueezeCPUKernel::Run() { MS_LOG(ERROR) << "Prepare failed."; return RET_ERROR; } - in_ptr_ = reinterpret_cast(in_tensors_.at(0)->Data()); - out_ptr_ = reinterpret_cast(out_tensors_.at(0)->Data()); + in_ptr_ = reinterpret_cast(in_tensors_.at(0)->Data()); + out_ptr_ = reinterpret_cast(out_tensors_.at(0)->Data()); ret = LiteBackendParallelLaunch(UnsqueezeRun, this, thread_sz_count_); if (ret != RET_OK) { MS_LOG(ERROR) << "UnsqueezeRun error error_code[" << ret << "]"; @@ -85,19 +83,19 @@ int UnsqueezeCPUKernel::Run() { kernel::LiteKernel *CpuUnsqueezeFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, - OpParameter *opParameter, const lite::Context *ctx, + OpParameter *parameter, const lite::Context *ctx, const kernel::KernelKey &desc, const lite::Primitive *primitive) { - MS_ASSERT(opParameter != nullptr); + MS_ASSERT(parameter != nullptr); MS_ASSERT(desc.type == schema::PrimitiveType_Unsqueeze); - auto *kernel = new (std::nothrow) UnsqueezeCPUKernel(opParameter, inputs, outputs, ctx, primitive); + auto *kernel = new (std::nothrow) UnsqueezeCPUKernel(parameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "new UnsqueezeCPUKernel fail!"; return nullptr; } auto ret = kernel->Init(); if (ret != RET_OK) { - MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " - << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); delete kernel; return nullptr; } @@ -105,4 +103,5 @@ kernel::LiteKernel *CpuUnsqueezeFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, const lite::Context *ctx, const lite::Primitive *primitive) - : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) {} + : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} ~UnsqueezeCPUKernel() = default; int Init() override; @@ -38,13 +38,11 @@ class UnsqueezeCPUKernel : public LiteKernel { int DoUnsqueeze(int task_id); private: - int thread_count_; int thread_sz_count_; int thread_sz_stride_; int data_size_; - float *in_ptr_; - float *out_ptr_; - const Context *ctx_; + int8_t *in_ptr_; + int8_t *out_ptr_; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/unsqueeze.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/unsqueeze.c index 6911b34cf80..4eba6b7edff 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/unsqueeze.c +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/unsqueeze.c @@ -18,7 +18,7 @@ #include #include "nnacl/errorcode.h" -int Unsqueeze(float *input_ptr, float *output_ptr, size_t data_size) { +int Unsqueeze(const int8_t *input_ptr, int8_t *output_ptr, size_t data_size) { memcpy(output_ptr, input_ptr, data_size); return NNACL_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/unsqueeze.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/unsqueeze.h index 12f37b7f30f..540916e1f0f 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/unsqueeze.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/unsqueeze.h @@ -30,7 +30,7 @@ typedef struct UnsqueezeParameter { #ifdef __cplusplus extern "C" { #endif -int Unsqueeze(float *input_ptr, float *output_ptr, size_t data_size); +int Unsqueeze(const int8_t *input_ptr, int8_t *output_ptr, size_t data_size); #ifdef __cplusplus } #endif