forked from mindspore-Ecosystem/mindspore
!4560 [MS][LITE][Develop]unsqueeze support int32
Merge pull request !4560 from chenjianping/lite_dev2
This commit is contained in:
commit
cde696477c
|
@ -28,17 +28,15 @@ using mindspore::schema::PrimitiveType_Unsqueeze;
|
||||||
|
|
||||||
namespace mindspore::kernel {
|
namespace mindspore::kernel {
|
||||||
int UnsqueezeCPUKernel::Init() {
|
int UnsqueezeCPUKernel::Init() {
|
||||||
if (context_->infer_shape_interrupt_ && !context_->running_) {
|
if (!InferShapeDone()) {
|
||||||
set_need_reinit();
|
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
int ret = ReSize();
|
return ReSize();
|
||||||
return ret;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int UnsqueezeCPUKernel::ReSize() {
|
int UnsqueezeCPUKernel::ReSize() {
|
||||||
data_size_ = in_tensors_.at(0)->ElementsNum();
|
data_size_ = in_tensors_.at(0)->ElementsNum();
|
||||||
thread_sz_count_ = MSMIN(thread_count_, data_size_);
|
thread_sz_count_ = MSMIN(context_->thread_num_, data_size_);
|
||||||
thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_);
|
thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_);
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
@ -48,7 +46,7 @@ int UnsqueezeCPUKernel::DoUnsqueeze(int task_id) {
|
||||||
if (size == 0) {
|
if (size == 0) {
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
int offset = task_id * thread_sz_stride_;
|
size_t offset = task_id * thread_sz_stride_ * sizeof(float);
|
||||||
int ret = Unsqueeze(in_ptr_ + offset, out_ptr_ + offset, size * sizeof(float));
|
int ret = Unsqueeze(in_ptr_ + offset, out_ptr_ + offset, size * sizeof(float));
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "UnsqueezeRun error task_id[" << task_id << "] error_code[" << ret << "]";
|
MS_LOG(ERROR) << "UnsqueezeRun error task_id[" << task_id << "] error_code[" << ret << "]";
|
||||||
|
@ -73,8 +71,8 @@ int UnsqueezeCPUKernel::Run() {
|
||||||
MS_LOG(ERROR) << "Prepare failed.";
|
MS_LOG(ERROR) << "Prepare failed.";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
in_ptr_ = reinterpret_cast<float *>(in_tensors_.at(0)->Data());
|
in_ptr_ = reinterpret_cast<int8_t *>(in_tensors_.at(0)->Data());
|
||||||
out_ptr_ = reinterpret_cast<float *>(out_tensors_.at(0)->Data());
|
out_ptr_ = reinterpret_cast<int8_t *>(out_tensors_.at(0)->Data());
|
||||||
ret = LiteBackendParallelLaunch(UnsqueezeRun, this, thread_sz_count_);
|
ret = LiteBackendParallelLaunch(UnsqueezeRun, this, thread_sz_count_);
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "UnsqueezeRun error error_code[" << ret << "]";
|
MS_LOG(ERROR) << "UnsqueezeRun error error_code[" << ret << "]";
|
||||||
|
@ -85,19 +83,19 @@ int UnsqueezeCPUKernel::Run() {
|
||||||
|
|
||||||
kernel::LiteKernel *CpuUnsqueezeFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
|
kernel::LiteKernel *CpuUnsqueezeFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
|
||||||
const std::vector<lite::tensor::Tensor *> &outputs,
|
const std::vector<lite::tensor::Tensor *> &outputs,
|
||||||
OpParameter *opParameter, const lite::Context *ctx,
|
OpParameter *parameter, const lite::Context *ctx,
|
||||||
const kernel::KernelKey &desc, const lite::Primitive *primitive) {
|
const kernel::KernelKey &desc, const lite::Primitive *primitive) {
|
||||||
MS_ASSERT(opParameter != nullptr);
|
MS_ASSERT(parameter != nullptr);
|
||||||
MS_ASSERT(desc.type == schema::PrimitiveType_Unsqueeze);
|
MS_ASSERT(desc.type == schema::PrimitiveType_Unsqueeze);
|
||||||
auto *kernel = new (std::nothrow) UnsqueezeCPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
auto *kernel = new (std::nothrow) UnsqueezeCPUKernel(parameter, inputs, outputs, ctx, primitive);
|
||||||
if (kernel == nullptr) {
|
if (kernel == nullptr) {
|
||||||
MS_LOG(ERROR) << "new UnsqueezeCPUKernel fail!";
|
MS_LOG(ERROR) << "new UnsqueezeCPUKernel fail!";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto ret = kernel->Init();
|
auto ret = kernel->Init();
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ << ", type: "
|
||||||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_));
|
||||||
delete kernel;
|
delete kernel;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -105,4 +103,5 @@ kernel::LiteKernel *CpuUnsqueezeFp32KernelCreator(const std::vector<lite::tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Unsqueeze, CpuUnsqueezeFp32KernelCreator)
|
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Unsqueeze, CpuUnsqueezeFp32KernelCreator)
|
||||||
|
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Unsqueeze, CpuUnsqueezeFp32KernelCreator)
|
||||||
} // namespace mindspore::kernel
|
} // namespace mindspore::kernel
|
||||||
|
|
|
@ -29,7 +29,7 @@ class UnsqueezeCPUKernel : public LiteKernel {
|
||||||
UnsqueezeCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
|
UnsqueezeCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
|
||||||
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
|
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
|
||||||
const lite::Primitive *primitive)
|
const lite::Primitive *primitive)
|
||||||
: LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) {}
|
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
|
||||||
~UnsqueezeCPUKernel() = default;
|
~UnsqueezeCPUKernel() = default;
|
||||||
|
|
||||||
int Init() override;
|
int Init() override;
|
||||||
|
@ -38,13 +38,11 @@ class UnsqueezeCPUKernel : public LiteKernel {
|
||||||
int DoUnsqueeze(int task_id);
|
int DoUnsqueeze(int task_id);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int thread_count_;
|
|
||||||
int thread_sz_count_;
|
int thread_sz_count_;
|
||||||
int thread_sz_stride_;
|
int thread_sz_stride_;
|
||||||
int data_size_;
|
int data_size_;
|
||||||
float *in_ptr_;
|
int8_t *in_ptr_;
|
||||||
float *out_ptr_;
|
int8_t *out_ptr_;
|
||||||
const Context *ctx_;
|
|
||||||
};
|
};
|
||||||
} // namespace mindspore::kernel
|
} // namespace mindspore::kernel
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
#include <string.h>
|
#include <string.h>
|
||||||
#include "nnacl/errorcode.h"
|
#include "nnacl/errorcode.h"
|
||||||
|
|
||||||
int Unsqueeze(float *input_ptr, float *output_ptr, size_t data_size) {
|
int Unsqueeze(const int8_t *input_ptr, int8_t *output_ptr, size_t data_size) {
|
||||||
memcpy(output_ptr, input_ptr, data_size);
|
memcpy(output_ptr, input_ptr, data_size);
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,7 +30,7 @@ typedef struct UnsqueezeParameter {
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
#endif
|
#endif
|
||||||
int Unsqueeze(float *input_ptr, float *output_ptr, size_t data_size);
|
int Unsqueeze(const int8_t *input_ptr, int8_t *output_ptr, size_t data_size);
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
Loading…
Reference in New Issue