forked from mindspore-Ecosystem/mindspore

add op_gather_int8 and testcase

commit cf56085cf0 (parent 75af54647f)
@@ -19,19 +19,13 @@
#include "nnacl/op_base.h"

typedef struct GatherParameter {
  OpParameter op_parameter_;
  int axis_;
  int batchDims_;
} GatherParameter;

#ifdef __cplusplus
extern "C" {
#endif
int Gather(float *input, int outer_size, int inner_size, int limit, int *indices, int indices_element_size,
           float *output);
int GatherInt32(const int32_t *input, int outer_size, int inner_size, int limit, int *indices,
                int indices_element_size, int32_t *output);
int GatherInt32(const int32_t *input, int outer_size, int inner_size, int limit, int *indices, int indices_element_size,
                int32_t *output);
#ifdef __cplusplus
}
#endif
@@ -0,0 +1,28 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_GATHER_PARAMETER_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_GATHER_PARAMETER_H_

#include "nnacl/op_base.h"

typedef struct GatherParameter {
  OpParameter op_parameter_;
  int axis_;
  int batchDims_;
} GatherParameter;

#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_GATHER_PARAMETER_H_
@@ -0,0 +1,34 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "nnacl/int8/gatherNd_int8.h"
#include <string.h>
#include "nnacl/errorcode.h"

int GatherNdInt8(int8_t *input, int8_t *output, int *in_offset, int area, int count, GatherQuantArg param) {
  double alpha = param.alpha_;
  int z1 = param.zp_in_;
  int z2 = param.zp_out_;
  for (int i = 0; i < count; ++i) {
    for (int j = 0; j < area; ++j) {
      int32_t tmp = round(alpha * (input[in_offset[i] + j] - z1)) + z2;
      tmp = tmp > 127 ? 127 : tmp;
      tmp = tmp < -128 ? -128 : tmp;
      output[area * i + j] = (int8_t)tmp;
    }
  }
  return NNACL_OK;
}
@@ -0,0 +1,31 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_GATHERND_INT8_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_GATHERND_INT8_H_

#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"

#ifdef __cplusplus
extern "C" {
#endif
int GatherNdInt8(int8_t *in_data, int8_t *out_data, int *in_offset, int area, int count, GatherQuantArg param);
#ifdef __cplusplus
}
#endif

#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_GATHERND_INT8_H_
@@ -0,0 +1,44 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */
#include "nnacl/int8/gather_int8.h"
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
#include "nnacl/errorcode.h"

int GatherInt8(int8_t *in_data, int8_t *out_data, int outer_size, int inner_size, int limit, int *indices,
               int indices_element_size, GatherQuantArg para) {
  double alpha = para.alpha_;
  int z1 = para.zp_in_;
  int z2 = para.zp_out_;
  int i, m, j;
  for (m = 0; m < outer_size; ++m) {
    const int8_t *inputm = in_data + inner_size * m * limit;
    int8_t *outputm = out_data + inner_size * m * indices_element_size;
    for (i = 0; i < indices_element_size; ++i) {
      if (indices[i] < 0 || indices[i] >= limit) {  // valid indices are [0, limit)
        return NNACL_ERR;
      }
      for (j = 0; j < inner_size; ++j) {
        int32_t tmp = round(alpha * (inputm[indices[i] * inner_size + j] - z1)) + z2;
        tmp = tmp > 127 ? 127 : tmp;
        tmp = tmp < -128 ? -128 : tmp;
        outputm[i * inner_size + j] = (int8_t)tmp;
      }
    }
  }
  return NNACL_OK;
}
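Both GatherNdInt8 and GatherInt8 above requantize every copied element the same way: alpha_ is the ratio of the input scale to the output scale, the input zero point is subtracted before scaling, the output zero point is added after rounding, and the result is clamped to [-128, 127]. A minimal standalone sketch of that mapping (the Requantize helper and the sample values are illustrative, not part of this commit):

#include <cmath>
#include <cstdint>
#include <cstdio>

// Requantize one int8 value from (in_scale, zp_in) to (out_scale, zp_out),
// mirroring the per-element formula in GatherInt8/GatherNdInt8.
static int8_t Requantize(int8_t x, double in_scale, int zp_in, double out_scale, int zp_out) {
  double alpha = in_scale / out_scale;  // corresponds to the alpha_ field of the quant arg
  int32_t tmp = static_cast<int32_t>(std::round(alpha * (x - zp_in))) + zp_out;
  tmp = tmp > 127 ? 127 : tmp;   // clamp to the int8 range
  tmp = tmp < -128 ? -128 : tmp;
  return static_cast<int8_t>(tmp);
}

int main() {
  // Quant params taken from the GatherNd test case below: input 0.5/1, output 1.0/0.
  std::printf("%d\n", Requantize(13, 0.5, 1, 1.0, 0));  // prints 6
  return 0;
}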
@@ -0,0 +1,32 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_GATHER_INT8_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_GATHER_INT8_H_

#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"

#ifdef __cplusplus
extern "C" {
#endif
int GatherInt8(int8_t *in_data, int8_t *out_data, int outer_size, int inner_size, int limit, int *indices,
               int indices_element_size, GatherQuantArg para);
#ifdef __cplusplus
}
#endif

#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_GATHER_INT8_H_
@@ -159,6 +159,12 @@ typedef struct ArithSelfQuantArg {
  int shift_right_;
} ArithSelfQuantArg;

typedef struct GatherQuantArg {
  double alpha_;
  int zp_in_;
  int zp_out_;
} GatherQuantArg;

typedef struct SplitQuantArg {
  QuantArg in_args_;
  QuantArg out_args_[20];
@@ -144,7 +144,7 @@
#include "nnacl/transpose.h"
#include "nnacl/split_parameter.h"
#include "nnacl/squeeze.h"
#include "nnacl/fp32/gather.h"
#include "nnacl/gather_parameter.h"
#include "nnacl/fp32/reverse.h"
#include "nnacl/reverse_sequence.h"
#include "nnacl/fp32/unique.h"
@@ -13,9 +13,10 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <vector>
#include "src/runtime/kernel/arm/fp32/gather.h"
#include <vector>
#include "nnacl/gather_parameter.h"
#include "nnacl/fp32/gather.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h"
@@ -18,7 +18,7 @@
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GATHER_H_

#include <vector>
#include "nnacl/fp32/gather.h"
#include "nnacl/gather_parameter.h"
#include "src/lite_kernel.h"

namespace mindspore::kernel {
@@ -0,0 +1,166 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/runtime/kernel/arm/int8/gatherNd_int8.h"
#include <string.h>
#include <vector>
#include "schema/model_generated.h"
#include "include/errorcode.h"
#include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h"
#include "nnacl/int8/gatherNd_int8.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_GatherNd;

namespace mindspore::kernel {

GatherNdInt8CPUKernel::~GatherNdInt8CPUKernel() {
  if (in_offset_ != nullptr) {
    free(in_offset_);
    in_offset_ = nullptr;
  }
}

int GatherNdInt8CPUKernel::Init() {
  if (!InferShapeDone()) {
    return RET_OK;
  }
  return ReSize();
}

int GatherNdInt8CPUKernel::ReSize() {
  if (in_offset_ != nullptr) {
    free(in_offset_);
    in_offset_ = nullptr;
  }
  auto in_quant_args = in_tensors_.at(0)->GetQuantParams();
  auto ind_quant_args = in_tensors_.at(1)->GetQuantParams();
  auto out_quant_args = out_tensors_.at(0)->GetQuantParams();
  param_.alpha_ = in_quant_args.front().scale / out_quant_args.front().scale;
  param_.zp_in_ = in_quant_args.front().zeroPoint;
  param_.zp_out_ = out_quant_args.front().zeroPoint;

  auto indices_tensor = in_tensors_.at(1);
  auto indices_shape = indices_tensor->shape();
  int indices_rank = indices_shape.size();
  count_ = 1;
  for (int i = 0; i < indices_rank - 1; ++i) {
    count_ *= indices_shape[i];
  }

  in_offset_ = reinterpret_cast<int *>(malloc(count_ * sizeof(int)));
  if (in_offset_ == nullptr) {
    MS_LOG(ERROR) << "GatherNdInt8 Malloc in_offset_ error!";
    return RET_ERROR;
  }
  (void)memset(in_offset_, 0, count_ * sizeof(int));

  thread_sz_count_ = MSMIN(thread_count_, count_);
  thread_sz_stride_ = UP_DIV(count_, thread_sz_count_);

  auto in_shape = in_tensors_.front()->shape();
  int in_rank = in_shape.size();
  int idx_lastshape = indices_shape[indices_rank - 1];
  auto indices_ptr = reinterpret_cast<int8_t *>(indices_tensor->Data());
  area_ = 1;
  for (int i = idx_lastshape; i < in_rank; ++i) {
    area_ *= in_shape[i];
  }
  std::vector<int> in_stride(in_rank);
  in_stride[in_rank - 1] = 1;
  for (int i = in_rank - 2; i >= 0; --i) {
    in_stride[i] = in_shape[i + 1] * in_stride[i + 1];
  }

  int idx_stride = idx_lastshape;
  for (int j = 0; j < count_; ++j) {
    for (int k = 0; k < idx_lastshape; ++k) {
      int tmp = static_cast<int>(
        round((indices_ptr[j * idx_stride + k] - ind_quant_args.front().zeroPoint) * ind_quant_args.front().scale));
      in_offset_[j] += tmp * in_stride[k];
    }
  }
  return RET_OK;
}

int GatherNdInt8CPUKernel::DoGatherNd(int task_id) {
  int count = MSMIN(thread_sz_stride_, count_ - task_id * thread_sz_stride_);
  if (count <= 0) {
    return RET_OK;
  }
  int offset = task_id * thread_sz_stride_;
  auto ret = GatherNdInt8(in_ptr_, out_ptr_ + offset * area_, in_offset_ + offset, area_, count, param_);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "GatherNdRun error task_id[" << task_id << "] error_code[" << ret << "]";
    return ret;
  }
  return RET_OK;
}

int GatherNdInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
  auto g_kernel = reinterpret_cast<GatherNdInt8CPUKernel *>(cdata);
  auto ret = g_kernel->DoGatherNd(task_id);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "GatherNdRun error task_id[" << task_id << "] error_code[" << ret << "]";
    return ret;
  }
  return RET_OK;
}

int GatherNdInt8CPUKernel::Run() {
  auto prepare_ret = Prepare();
  if (prepare_ret != RET_OK) {
    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
    return prepare_ret;
  }
  in_ptr_ = reinterpret_cast<int8_t *>(in_tensors_.front()->Data());
  out_ptr_ = reinterpret_cast<int8_t *>(out_tensors_.front()->Data());
  auto ret = LiteBackendParallelLaunch(GatherNdInt8Run, this, thread_sz_count_);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "gatherNd error error_code[" << ret << "]";
    return ret;
  }
  return RET_OK;
}

kernel::LiteKernel *CpuGatherNdInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                 const std::vector<lite::tensor::Tensor *> &outputs,
                                                 OpParameter *opParameter, const lite::Context *ctx,
                                                 const kernel::KernelKey &desc,
                                                 const mindspore::lite::PrimitiveC *primitive) {
  MS_ASSERT(opParameter != nullptr);
  MS_ASSERT(desc.type == schema::PrimitiveType_GatherNd);

  auto *kernel = new (std::nothrow) GatherNdInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
  if (kernel == nullptr) {
    return nullptr;
  }
  auto ret = kernel->Init();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
                  << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
    delete kernel;
    return nullptr;
  }
  return kernel;
}

REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_GatherNd, CpuGatherNdInt8KernelCreator)
}  // namespace mindspore::kernel
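GatherNdInt8CPUKernel::ReSize() above precomputes one flat input offset per index row: the int8 indices are dequantized back to integers and combined with the row-major strides of the input shape, while area_ is the number of trailing elements copied per row. A self-contained sketch of that offset computation (function and variable names are illustrative, not the kernel code itself):

#include <cmath>
#include <cstdint>
#include <vector>

// Compute GatherNd-style flat offsets. `indices` holds `count` rows of `idx_rank`
// int8 entries quantized with (scale, zero_point); `in_shape` is the input shape.
std::vector<int> GatherNdOffsets(const std::vector<int8_t> &indices, int count, int idx_rank, float scale,
                                 int zero_point, const std::vector<int> &in_shape) {
  int in_rank = static_cast<int>(in_shape.size());
  std::vector<int> stride(in_rank, 1);
  for (int i = in_rank - 2; i >= 0; --i) {
    stride[i] = stride[i + 1] * in_shape[i + 1];  // row-major strides
  }
  std::vector<int> offsets(count, 0);
  for (int j = 0; j < count; ++j) {
    for (int k = 0; k < idx_rank; ++k) {
      int idx = static_cast<int>(std::round((indices[j * idx_rank + k] - zero_point) * scale));
      offsets[j] += idx * stride[k];
    }
  }
  return offsets;
}

For the {3, 3} indices and {1, 2, 2, 5} input of the GatherNd test below, this yields offsets {15, 5, 10}, with area_ = 5 elements copied per offset.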
@@ -0,0 +1,51 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_GATHERND_INT8_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_GATHERND_INT8_H_

#include <vector>
#include "nnacl/quantization/quantize.h"
#include "src/lite_kernel.h"

namespace mindspore::kernel {
class GatherNdInt8CPUKernel : public LiteKernel {
 public:
  GatherNdInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
                        const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
                        const mindspore::lite::PrimitiveC *primitive)
      : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {}
  ~GatherNdInt8CPUKernel() override;

  int Init() override;
  int ReSize() override;
  int Run() override;
  int DoGatherNd(int task_id);

 private:
  int thread_count_;
  int thread_sz_count_;
  int thread_sz_stride_;
  int count_;
  int area_;
  int *in_offset_ = nullptr;
  int8_t *in_ptr_;
  int8_t *out_ptr_;
  GatherQuantArg param_;
};
}  // namespace mindspore::kernel

#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_GATHERND_INT8_H_
@@ -0,0 +1,163 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/runtime/kernel/arm/int8/gather_int8.h"
#include <vector>
#include "nnacl/gather_parameter.h"
#include "nnacl/int8/gather_int8.h"
#include "nnacl/quantization/quantize.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h"
#include "include/errorcode.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Gather;

namespace mindspore::kernel {

int GatherInt8CPUKernel::Init() {
  axis_ = (reinterpret_cast<GatherParameter *>(op_parameter_))->axis_;
  batchDims_ = (reinterpret_cast<GatherParameter *>(op_parameter_))->batchDims_;
  auto in_quant_args = in_tensors_.at(0)->GetQuantParams();
  auto ind_quant_args = in_tensors_.at(1)->GetQuantParams();
  auto out_quant_args = out_tensors_.at(0)->GetQuantParams();
  param_.alpha_ = in_quant_args.front().scale / out_quant_args.front().scale;
  param_.zp_in_ = in_quant_args.front().zeroPoint;
  param_.zp_out_ = out_quant_args.front().zeroPoint;

  auto indices_ptr = reinterpret_cast<int8_t *>(in_tensors_.at(1)->Data());
  if (indices_ != nullptr) {
    free(indices_);
    indices_ = nullptr;
  }
  int count = in_tensors_.at(1)->ElementsNum();
  indices_ = reinterpret_cast<int *>(malloc(count * sizeof(int)));
  if (indices_ == nullptr) {
    MS_LOG(ERROR) << "Gather Malloc indices_ error!";
    return RET_ERROR;
  }
  (void)memset(indices_, 0, count * sizeof(int));
  for (int i = 0; i < count; ++i) {
    indices_[i] =
      static_cast<int>(round((indices_ptr[i] - ind_quant_args.front().zeroPoint) * ind_quant_args.front().scale));
  }

  if (!InferShapeDone()) {
    return RET_OK;
  }
  return ReSize();
}

int GatherInt8CPUKernel::ReSize() { return RET_OK; }

int GatherInt8CPUKernel::DoGather(int task_id) {
  auto input_tensor = in_tensors_.at(0);
  auto indices_tensor = in_tensors_.at(1);
  auto out_tensor = out_tensors_.at(0);

  auto input_ptr = reinterpret_cast<int8_t *>(input_tensor->Data());
  auto output_ptr = reinterpret_cast<int8_t *>(out_tensor->Data());

  auto in_shape = input_tensor->shape();
  int in_rank = in_shape.size();
  int indices_element_size = indices_tensor->ElementsNum();

  const int limit = in_shape[axis_];
  for (int i = 0; i < indices_element_size; ++i) {
    if (indices_[i] >= limit) {
      MS_LOG(ERROR) << " indice data: " << indices_[i] << " is not in [ 0, " << limit - 1 << " ]";
      return RET_ERROR;
    }
  }

  int outer_size = 1;
  for (int i = 0; i < axis_; ++i) {
    outer_size *= in_shape[i];
  }

  int inner_size = 1;
  for (int i = axis_ + 1; i < in_rank; ++i) {
    inner_size *= in_shape[i];
  }

  int stride = UP_DIV(outer_size, thread_count_);
  int count = MSMIN(stride, outer_size - stride * task_id);
  auto thread_stride = stride * task_id;

  int error_code;
  // Each outer slice holds inner_size * limit input elements and
  // inner_size * indices_element_size output elements.
  input_ptr += thread_stride * limit * inner_size;
  output_ptr += thread_stride * indices_element_size * inner_size;
  error_code = GatherInt8(input_ptr, output_ptr, count, inner_size, limit, indices_, indices_element_size, param_);

  if (error_code != RET_OK) {
    return RET_ERROR;
  }
  return RET_OK;
}

int GatherInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
  auto gather_kernel = reinterpret_cast<GatherInt8CPUKernel *>(cdata);
  auto error_code = gather_kernel->DoGather(task_id);
  if (error_code != RET_OK) {
    MS_LOG(ERROR) << "GatherRun error task_id[" << task_id << "] error_code[" << error_code << "]";
    return RET_ERROR;
  }
  return RET_OK;
}

int GatherInt8CPUKernel::Run() {
  auto prepare_ret = Prepare();
  if (prepare_ret != RET_OK) {
    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
    return prepare_ret;
  }
  int error_code = LiteBackendParallelLaunch(GatherInt8Run, this, thread_count_);
  if (error_code != RET_OK) {
    MS_LOG(ERROR) << "Gather function error error_code[" << error_code << "]";
    return RET_ERROR;
  }
  return RET_OK;
}

kernel::LiteKernel *CpuGatherInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                               const std::vector<lite::tensor::Tensor *> &outputs,
                                               OpParameter *opParameter, const lite::Context *ctx,
                                               const kernel::KernelKey &desc,
                                               const mindspore::lite::PrimitiveC *primitive) {
  MS_ASSERT(desc.type == schema::PrimitiveType_Gather);
  if (opParameter == nullptr) {
    MS_LOG(ERROR) << "input parameter is nullptr!";
    return nullptr;
  }
  auto *kernel = new (std::nothrow) GatherInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
  if (kernel == nullptr) {
    return nullptr;
  }
  auto ret = kernel->Init();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
                  << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
    delete kernel;
    return nullptr;
  }
  return kernel;
}

REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Gather, CpuGatherInt8KernelCreator)
}  // namespace mindspore::kernel
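GatherInt8CPUKernel::DoGather() above splits the input shape around axis_: outer_size is the product of the dimensions before the axis, limit is the axis length, and inner_size is the product of the dimensions after it; the threads then divide the outer_size slices between them. A small illustration of that decomposition (helper name and example are illustrative only):

#include <cstdio>
#include <vector>

struct GatherDims {
  int outer_size;
  int limit;
  int inner_size;
};

// Split `shape` around `axis` the same way DoGather does before calling GatherInt8.
GatherDims SplitForGather(const std::vector<int> &shape, int axis) {
  GatherDims d{1, shape[axis], 1};
  for (int i = 0; i < axis; ++i) d.outer_size *= shape[i];
  for (int i = axis + 1; i < static_cast<int>(shape.size()); ++i) d.inner_size *= shape[i];
  return d;
}

int main() {
  // Shape and axis from the Gather test below: {2, 1, 3, 2}, axis 0.
  GatherDims d = SplitForGather({2, 1, 3, 2}, 0);
  std::printf("outer=%d limit=%d inner=%d\n", d.outer_size, d.limit, d.inner_size);  // outer=1 limit=2 inner=6
  return 0;
}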
@@ -0,0 +1,51 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_GATHER_INT8_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_GATHER_INT8_H_

#include <vector>
#include "nnacl/gather_parameter.h"
#include "nnacl/quantization/quantize.h"
#include "src/lite_kernel.h"

namespace mindspore::kernel {
class GatherInt8CPUKernel : public LiteKernel {
 public:
  GatherInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
                      const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
                      const mindspore::lite::PrimitiveC *primitive)
      : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {}
  ~GatherInt8CPUKernel() {
    free(indices_);
    indices_ = nullptr;
  }

  int Init() override;
  int ReSize() override;
  int Run() override;
  int DoGather(int task_id);

 private:
  int *indices_ = nullptr;
  int thread_count_;
  int batchDims_;
  int axis_;
  GatherQuantArg param_;
};
}  // namespace mindspore::kernel

#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_GATHER_INT8_H_
@@ -0,0 +1,101 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <iostream>
#include "mindspore/core/utils/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/nnacl/fp32/gatherNd.h"
#include "mindspore/lite/nnacl/int8/gatherNd_int8.h"
#include "mindspore/lite/src/kernel_registry.h"
#include "mindspore/lite/src/lite_kernel.h"

namespace mindspore {
class TestGatherNdInt8 : public mindspore::CommonTest {
 public:
  TestGatherNdInt8() {}
};

TEST_F(TestGatherNdInt8, GatherNdTest) {
  std::vector<int8_t> in_data = {3, 5, 7, 9, 11, 13, 15, 17, 19, 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 1};
  std::vector<int8_t> in_data1 = {2, 4, 4, 2, 2, 4, 2, 4, 2};
  // std::vector<int8_t> in_data1 = {2, 2, 2, 4};

  std::vector<lite::tensor::Tensor *> inputs_tensor;
  std::vector<lite::tensor::Tensor *> outputs_tensor;

  GatherNdParameter op_param;
  op_param.op_parameter_.type_ = schema::PrimitiveType_GatherNd;
  op_param.batchDims_ = 1;
  std::vector<int> shape = {1, 2, 2, 5};
  std::vector<int> out_shape = {1, 3, 5};

  lite::tensor::QuantArg input_quant_arg;
  input_quant_arg.scale = 0.5;
  input_quant_arg.zeroPoint = 1;
  lite::tensor::QuantArg input_quant_arg_1;
  input_quant_arg_1.scale = 0.5;
  input_quant_arg_1.zeroPoint = 2;
  lite::tensor::QuantArg output_quant_arg;
  output_quant_arg.scale = 1;
  output_quant_arg.zeroPoint = 0;

  lite::tensor::Tensor input0_tensor;
  lite::tensor::Tensor input1_tensor;

  inputs_tensor.push_back(&input0_tensor);
  inputs_tensor.push_back(&input1_tensor);

  input0_tensor.SetData(in_data.data());
  input1_tensor.SetData(in_data1.data());

  input0_tensor.set_shape(shape);
  input1_tensor.set_shape({3, 3});

  input0_tensor.AddQuantParam(input_quant_arg);
  input1_tensor.AddQuantParam(input_quant_arg_1);

  std::vector<int8_t> output(15);
  // std::vector<int8_t> corr_out = {1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0};
  std::vector<int8_t> corr_out = {6, 7, 8, 9, 0, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5};
  lite::tensor::Tensor output0_tensor;
  outputs_tensor.push_back(&output0_tensor);
  output0_tensor.SetData(output.data());
  output0_tensor.set_shape(out_shape);
  output0_tensor.AddQuantParam(output_quant_arg);

  kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_GatherNd};
  auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
  ASSERT_NE(creator, nullptr);
  lite::Context ctx;
  ctx.thread_num_ = 3;
  kernel::LiteKernel *kernel =
    creator(inputs_tensor, outputs_tensor, reinterpret_cast<OpParameter *>(&op_param), &ctx, desc, nullptr);
  ASSERT_NE(kernel, nullptr);
  auto output_tensor_shape = output0_tensor.shape();
  kernel->Run();

  printf("==================output data=================\n");
  for (int i = 0; i < output0_tensor.ElementsNum(); i++) {
    printf("%d, ", output[i]);
  }
  std::cout << std::endl;
  CompareOutputData(output.data(), corr_out.data(), output0_tensor.ElementsNum(), 0.001);

  input0_tensor.SetData(nullptr);
  input1_tensor.SetData(nullptr);
  output0_tensor.SetData(nullptr);
  MS_LOG(INFO) << "TestGatherNd accuracy passed";
}
}  // namespace mindspore
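As a sanity check on corr_out (worked by hand, not part of the commit): the int8 indices {2, 4, 4, 2, 2, 4, 2, 4, 2} dequantize with scale 0.5 and zero point 2 to the index triples (0, 1, 1), (0, 0, 1) and (0, 1, 0). With input shape {1, 2, 2, 5} the row-major strides are {20, 10, 5, 1}, so the flat offsets are 15, 5 and 10 and area_ = 5. Requantizing the five input values at each offset with alpha = 0.5 and zero points 1 -> 0 yields {6, 7, 8, 9, 0}, {6, 7, 8, 9, 0} and {1, 2, 3, 4, 5}, which is exactly corr_out.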
@@ -0,0 +1,99 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <iostream>
#include "mindspore/core/utils/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/nnacl/gather_parameter.h"
#include "mindspore/lite/nnacl/int8/gather_int8.h"
#include "mindspore/lite/src/kernel_registry.h"
#include "mindspore/lite/src/lite_kernel.h"

namespace mindspore {
class TestGatherInt8 : public mindspore::CommonTest {
 public:
  TestGatherInt8() {}
};

TEST_F(TestGatherInt8, GatherTest) {
  std::vector<int8_t> in_data = {11, 41, 21, 51, 31, 61, -11, -41, -21, -51, -31, -61};
  std::vector<int8_t> in_data1 = {4, 2};
  std::vector<lite::tensor::Tensor *> inputs_tensor;
  std::vector<lite::tensor::Tensor *> outputs_tensor;

  GatherParameter op_param;
  op_param.op_parameter_.type_ = schema::PrimitiveType_Gather;
  op_param.axis_ = 0;
  op_param.batchDims_ = 1;
  std::vector<int> shape = {2, 1, 3, 2};

  lite::tensor::QuantArg input_quant_arg;
  input_quant_arg.scale = 0.1;
  input_quant_arg.zeroPoint = 1;
  lite::tensor::QuantArg input_quant_arg_1;
  input_quant_arg_1.scale = 0.5;
  input_quant_arg_1.zeroPoint = 2;
  lite::tensor::QuantArg output_quant_arg;
  output_quant_arg.scale = 0.1;
  output_quant_arg.zeroPoint = 1;

  lite::tensor::Tensor input0_tensor;
  lite::tensor::Tensor input1_tensor;

  inputs_tensor.push_back(&input0_tensor);
  inputs_tensor.push_back(&input1_tensor);

  input0_tensor.SetData(in_data.data());
  input1_tensor.SetData(in_data1.data());

  input0_tensor.set_shape(shape);
  input1_tensor.set_shape({2});

  input0_tensor.AddQuantParam(input_quant_arg);
  input1_tensor.AddQuantParam(input_quant_arg_1);

  std::vector<int8_t> output(12);
  // std::vector<int8_t> corr_out = {-18, -22, -16, -21, -14, -19, -22, -34, -24, -35, -26, -36 };
  std::vector<int8_t> corr_out = {-11, -41, -21, -51, -31, -61, 11, 41, 21, 51, 31, 61};
  lite::tensor::Tensor output0_tensor;
  outputs_tensor.push_back(&output0_tensor);
  output0_tensor.SetData(output.data());
  output0_tensor.set_shape(shape);
  output0_tensor.AddQuantParam(output_quant_arg);

  kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Gather};
  auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
  ASSERT_NE(creator, nullptr);
  lite::Context ctx;
  ctx.thread_num_ = 3;
  kernel::LiteKernel *kernel =
    creator(inputs_tensor, outputs_tensor, reinterpret_cast<OpParameter *>(&op_param), &ctx, desc, nullptr);
  ASSERT_NE(kernel, nullptr);
  auto output_tensor_shape = output0_tensor.shape();
  kernel->Run();

  printf("==================output data=================\n");
  for (int i = 0; i < output0_tensor.ElementsNum(); i++) {
    printf("%d, ", output[i]);
  }
  std::cout << std::endl;
  CompareOutputData(output.data(), corr_out.data(), output0_tensor.ElementsNum(), 0.001);

  input0_tensor.SetData(nullptr);
  input1_tensor.SetData(nullptr);
  output0_tensor.SetData(nullptr);
  MS_LOG(INFO) << "TestGather_int8 accuracy passed";
}
}  // namespace mindspore
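The expected output of this test can be checked the same way (again worked by hand, not part of the commit): the indices {4, 2} dequantize with scale 0.5 and zero point 2 to {1, 0}; with axis 0 and shape {2, 1, 3, 2} we get limit = 2, outer_size = 1 and inner_size = 6. Input and output share scale 0.1 and zero point 1, so the requantization is the identity and the output is row 1 of in_data followed by row 0, matching corr_out.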