unsqueeze support int32

This commit is contained in:
chenjianping 2020-08-17 09:59:48 +08:00
parent 857c0301a8
commit baf64e498e
4 changed files with 17 additions and 20 deletions

View File

@ -28,17 +28,15 @@ using mindspore::schema::PrimitiveType_Unsqueeze;
namespace mindspore::kernel { namespace mindspore::kernel {
int UnsqueezeCPUKernel::Init() { int UnsqueezeCPUKernel::Init() {
if (context_->infer_shape_interrupt_ && !context_->running_) { if (!InferShapeDone()) {
set_need_reinit();
return RET_OK; return RET_OK;
} }
int ret = ReSize(); return ReSize();
return ret;
} }
int UnsqueezeCPUKernel::ReSize() { int UnsqueezeCPUKernel::ReSize() {
data_size_ = in_tensors_.at(0)->ElementsNum(); data_size_ = in_tensors_.at(0)->ElementsNum();
thread_sz_count_ = MSMIN(thread_count_, data_size_); thread_sz_count_ = MSMIN(context_->thread_num_, data_size_);
thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_); thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_);
return RET_OK; return RET_OK;
} }
@ -48,7 +46,7 @@ int UnsqueezeCPUKernel::DoUnsqueeze(int task_id) {
if (size == 0) { if (size == 0) {
return RET_OK; return RET_OK;
} }
int offset = task_id * thread_sz_stride_; size_t offset = task_id * thread_sz_stride_ * sizeof(float);
int ret = Unsqueeze(in_ptr_ + offset, out_ptr_ + offset, size * sizeof(float)); int ret = Unsqueeze(in_ptr_ + offset, out_ptr_ + offset, size * sizeof(float));
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "UnsqueezeRun error task_id[" << task_id << "] error_code[" << ret << "]"; MS_LOG(ERROR) << "UnsqueezeRun error task_id[" << task_id << "] error_code[" << ret << "]";
@ -73,8 +71,8 @@ int UnsqueezeCPUKernel::Run() {
MS_LOG(ERROR) << "Prepare failed."; MS_LOG(ERROR) << "Prepare failed.";
return RET_ERROR; return RET_ERROR;
} }
in_ptr_ = reinterpret_cast<float *>(in_tensors_.at(0)->Data()); in_ptr_ = reinterpret_cast<int8_t *>(in_tensors_.at(0)->Data());
out_ptr_ = reinterpret_cast<float *>(out_tensors_.at(0)->Data()); out_ptr_ = reinterpret_cast<int8_t *>(out_tensors_.at(0)->Data());
ret = LiteBackendParallelLaunch(UnsqueezeRun, this, thread_sz_count_); ret = LiteBackendParallelLaunch(UnsqueezeRun, this, thread_sz_count_);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "UnsqueezeRun error error_code[" << ret << "]"; MS_LOG(ERROR) << "UnsqueezeRun error error_code[" << ret << "]";
@ -85,19 +83,19 @@ int UnsqueezeCPUKernel::Run() {
kernel::LiteKernel *CpuUnsqueezeFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, kernel::LiteKernel *CpuUnsqueezeFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx, OpParameter *parameter, const lite::Context *ctx,
const kernel::KernelKey &desc, const lite::Primitive *primitive) { const kernel::KernelKey &desc, const lite::Primitive *primitive) {
MS_ASSERT(opParameter != nullptr); MS_ASSERT(parameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_Unsqueeze); MS_ASSERT(desc.type == schema::PrimitiveType_Unsqueeze);
auto *kernel = new (std::nothrow) UnsqueezeCPUKernel(opParameter, inputs, outputs, ctx, primitive); auto *kernel = new (std::nothrow) UnsqueezeCPUKernel(parameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) { if (kernel == nullptr) {
MS_LOG(ERROR) << "new UnsqueezeCPUKernel fail!"; MS_LOG(ERROR) << "new UnsqueezeCPUKernel fail!";
return nullptr; return nullptr;
} }
auto ret = kernel->Init(); auto ret = kernel->Init();
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_));
delete kernel; delete kernel;
return nullptr; return nullptr;
} }
@ -105,4 +103,5 @@ kernel::LiteKernel *CpuUnsqueezeFp32KernelCreator(const std::vector<lite::tensor
} }
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Unsqueeze, CpuUnsqueezeFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Unsqueeze, CpuUnsqueezeFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Unsqueeze, CpuUnsqueezeFp32KernelCreator)
} // namespace mindspore::kernel } // namespace mindspore::kernel

View File

@ -29,7 +29,7 @@ class UnsqueezeCPUKernel : public LiteKernel {
UnsqueezeCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, UnsqueezeCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx, const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const lite::Primitive *primitive) const lite::Primitive *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) {} : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
~UnsqueezeCPUKernel() = default; ~UnsqueezeCPUKernel() = default;
int Init() override; int Init() override;
@ -38,13 +38,11 @@ class UnsqueezeCPUKernel : public LiteKernel {
int DoUnsqueeze(int task_id); int DoUnsqueeze(int task_id);
private: private:
int thread_count_;
int thread_sz_count_; int thread_sz_count_;
int thread_sz_stride_; int thread_sz_stride_;
int data_size_; int data_size_;
float *in_ptr_; int8_t *in_ptr_;
float *out_ptr_; int8_t *out_ptr_;
const Context *ctx_;
}; };
} // namespace mindspore::kernel } // namespace mindspore::kernel

View File

@ -18,7 +18,7 @@
#include <string.h> #include <string.h>
#include "nnacl/errorcode.h" #include "nnacl/errorcode.h"
int Unsqueeze(float *input_ptr, float *output_ptr, size_t data_size) { int Unsqueeze(const int8_t *input_ptr, int8_t *output_ptr, size_t data_size) {
memcpy(output_ptr, input_ptr, data_size); memcpy(output_ptr, input_ptr, data_size);
return NNACL_OK; return NNACL_OK;
} }

View File

@ -30,7 +30,7 @@ typedef struct UnsqueezeParameter {
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
int Unsqueeze(float *input_ptr, float *output_ptr, size_t data_size); int Unsqueeze(const int8_t *input_ptr, int8_t *output_ptr, size_t data_size);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif