forked from mindspore-Ecosystem/mindspore
!4335 Add arm op Div for int8 and testcases
Merge pull request !4335 from wangminggui/master
This commit is contained in:
commit
fb2f888ec8
|
@ -0,0 +1,194 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/runtime/kernel/arm/int8/div_int8.h"
|
||||
#include <limits>
|
||||
#include <algorithm>
|
||||
#include "src/runtime/kernel/arm/nnacl/arithmetic_common.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::PrimitiveType_Div;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
int DivInt8CPUKernel::Init() {
|
||||
lite::tensor::Tensor *input0 = in_tensors_.at(0);
|
||||
lite::tensor::Tensor *input1 = in_tensors_.at(1);
|
||||
lite::tensor::Tensor *output = out_tensors_.at(0);
|
||||
MS_ASSERT(input0);
|
||||
MS_ASSERT(input1);
|
||||
MS_ASSERT(output);
|
||||
|
||||
broadcast_ = input0->ElementsNum() != input1->ElementsNum();
|
||||
|
||||
param_.in0_args_.scale_ = input0->GetQuantParams().front().scale;
|
||||
param_.in0_args_.zp_ = -input0->GetQuantParams().front().zeroPoint;
|
||||
param_.in1_args_.scale_ = input1->GetQuantParams().front().scale;
|
||||
param_.in1_args_.zp_ = -input1->GetQuantParams().front().zeroPoint;
|
||||
param_.out_args_.scale_ = output->GetQuantParams().front().scale;
|
||||
param_.out_args_.zp_ = output->GetQuantParams().front().zeroPoint;
|
||||
|
||||
const double real_multiplier = param_.in0_args_.scale_ / (param_.in1_args_.scale_ * param_.out_args_.scale_);
|
||||
|
||||
QuantizeMultiplier(real_multiplier, ¶m_.output_multiplier_, ¶m_.output_shift_);
|
||||
|
||||
param_.output_activation_min_ = std::numeric_limits<int8_t>::min();
|
||||
param_.output_activation_max_ = std::numeric_limits<int8_t>::max();
|
||||
|
||||
if (!InferShapeDone()) {
|
||||
return RET_OK;
|
||||
}
|
||||
return ReSize();
|
||||
}
|
||||
|
||||
int DivInt8CPUKernel::ReSize() {
|
||||
if (broadcast_) {
|
||||
if (tile0_data_ != nullptr) {
|
||||
if (context_ != nullptr && context_->allocator != nullptr) {
|
||||
context_->allocator->Free(tile0_data_);
|
||||
} else {
|
||||
free(tile0_data_);
|
||||
}
|
||||
}
|
||||
if (tile1_data_ != nullptr) {
|
||||
if (context_ != nullptr && context_->allocator != nullptr) {
|
||||
context_->allocator->Free(tile1_data_);
|
||||
} else {
|
||||
free(tile1_data_);
|
||||
}
|
||||
}
|
||||
|
||||
if (context_ != nullptr && context_->allocator != nullptr) {
|
||||
tile0_data_ = static_cast<int8_t *>(context_->allocator->Malloc(out_tensors_.at(0)->Size()));
|
||||
tile1_data_ = static_cast<int8_t *>(context_->allocator->Malloc(out_tensors_.at(0)->Size()));
|
||||
} else {
|
||||
tile0_data_ = static_cast<int8_t *>(malloc(sizeof(int8_t) * out_tensors_.at(0)->Size()));
|
||||
tile1_data_ = static_cast<int8_t *>(malloc(sizeof(int8_t) * out_tensors_.at(0)->Size()));
|
||||
}
|
||||
|
||||
if (tile0_data_ == nullptr || tile1_data_ == nullptr) {
|
||||
if (tile0_data_ != nullptr) {
|
||||
if (context_ != nullptr && context_->allocator != nullptr) {
|
||||
context_->allocator->Free(tile0_data_);
|
||||
} else {
|
||||
free(tile0_data_);
|
||||
}
|
||||
}
|
||||
if (tile1_data_ != nullptr) {
|
||||
if (context_ != nullptr && context_->allocator != nullptr) {
|
||||
context_->allocator->Free(tile1_data_);
|
||||
} else {
|
||||
free(tile1_data_);
|
||||
}
|
||||
}
|
||||
MS_LOG(ERROR) << "malloc memroy fail!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int DivInt8CPUKernel::DoExecute(int task_id) {
|
||||
auto input0_data_ = static_cast<int8_t *>(in_tensors_.at(0)->Data());
|
||||
auto input1_data_ = static_cast<int8_t *>(in_tensors_.at(1)->Data());
|
||||
auto output_data_ = static_cast<int8_t *>(out_tensors_.at(0)->Data());
|
||||
auto element_num = out_tensors_[0]->ElementsNum();
|
||||
|
||||
MS_ASSERT(op_parameter_->thread_num_ != 0);
|
||||
int stride = UP_DIV(element_num, op_parameter_->thread_num_);
|
||||
int count = MSMIN(stride, element_num - stride * task_id);
|
||||
|
||||
auto ret = RET_OK;
|
||||
if (broadcast_) {
|
||||
ret = DivInt8(tile0_data_ + task_id * count, tile1_data_ + task_id * count, output_data_ + task_id * count, count,
|
||||
¶m_);
|
||||
} else {
|
||||
ret = DivInt8(input0_data_ + task_id * count, input1_data_ + task_id * count, output_data_ + task_id * count, count,
|
||||
¶m_);
|
||||
}
|
||||
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Divint8 function error error_code[" << ret << "]";
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
int DivInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
|
||||
auto div_kernel = reinterpret_cast<DivInt8CPUKernel *>(cdata);
|
||||
auto ret = div_kernel->DoExecute(task_id);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "DivInt8 DoExecute error task_id[" << task_id << "] error_code[" << ret << "]";
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
int DivInt8CPUKernel::Run() {
|
||||
auto ret = Prepare();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Prepare failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
if (broadcast_) {
|
||||
ArithmeticParameter tile_para = {0};
|
||||
tile_para.ndim_ = out_tensors_.at(0)->shape().size();
|
||||
for (size_t i = 0; i < tile_para.ndim_; i++) {
|
||||
tile_para.in_shape0_[i] = in_tensors_.at(0)->DimensionSize(i);
|
||||
tile_para.in_shape1_[i] = in_tensors_.at(1)->DimensionSize(i);
|
||||
tile_para.out_shape_[i] = out_tensors_.at(0)->DimensionSize(i);
|
||||
}
|
||||
TileDimensionsUint8(static_cast<uint8_t *>(in_tensors_.at(0)->Data()),
|
||||
static_cast<uint8_t *>(in_tensors_.at(1)->Data()), reinterpret_cast<uint8_t *>(tile0_data_),
|
||||
reinterpret_cast<uint8_t *>(tile1_data_), &tile_para);
|
||||
}
|
||||
ret = LiteBackendParallelLaunch(DivInt8Run, this, op_parameter_->thread_num_);
|
||||
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "DivInt8Run function error error_code[" << ret << "]";
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
kernel::LiteKernel *CpuDivInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs, OpParameter *parameter,
|
||||
const lite::Context *ctx, const KernelKey &desc,
|
||||
const lite::Primitive *primitive) {
|
||||
if (parameter == nullptr || ctx == nullptr) {
|
||||
MS_LOG(ERROR) << "parameter or ctx is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
MS_ASSERT(desc.type == PrimitiveType_Div);
|
||||
auto *kernel = new (std::nothrow) DivInt8CPUKernel(parameter, inputs, outputs, ctx, primitive);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "kernel is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
auto ret = kernel->Init();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_
|
||||
<< ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_));
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Div, CpuDivInt8KernelCreator)
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,46 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DIV_INT8_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DIV_INT8_H_
|
||||
|
||||
#include <vector>
|
||||
#include "src/lite_kernel.h"
|
||||
#include "src/runtime/kernel/arm/nnacl/int8/div_int8.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class DivInt8CPUKernel : public LiteKernel {
|
||||
public:
|
||||
explicit DivInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
|
||||
const lite::Primitive *primitive)
|
||||
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
|
||||
~DivInt8CPUKernel() override {}
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
int DoExecute(int task_id);
|
||||
|
||||
private:
|
||||
DivQuantArg param_;
|
||||
int8_t *tile0_data_ = nullptr;
|
||||
int8_t *tile1_data_ = nullptr;
|
||||
bool broadcast_ = false;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DIV_INT8_H_
|
|
@ -0,0 +1,43 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "nnacl/int8/div_int8.h"
|
||||
#include "nnacl/quantization/fixed_point.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
int DivInt8(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, DivQuantArg *para) {
|
||||
int index = 0;
|
||||
for (; index < real_dst_count; ++index) {
|
||||
const int32_t input0_val = para->in0_args_.zp_ + input0_data[index];
|
||||
const int32_t input1_val = para->in1_args_.zp_ + input1_data[index];
|
||||
if (input1_val == 0) {
|
||||
return NNACL_ERRCODE_DIVISOR_ZERO;
|
||||
}
|
||||
|
||||
int recip_shift;
|
||||
const int32_t input1_inv = (input1_val > 0) ? ComputerReciproal(input1_val, 31, &recip_shift)
|
||||
: -ComputerReciproal(-input1_val, 31, &recip_shift);
|
||||
const int leading_bits = CountLeadingSignBits(input0_val);
|
||||
const int32_t raw_data =
|
||||
SaturatingRoundingDoublingHighMul(input0_val * (1 << (unsigned int)leading_bits), input1_inv);
|
||||
const int total_shift = para->output_shift_ - recip_shift - leading_bits;
|
||||
const int32_t raw_output =
|
||||
RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(raw_data, para->output_multiplier_), -total_shift) +
|
||||
para->out_args_.zp_;
|
||||
output_data[index] = (int8_t)MSMAX(para->output_activation_min_, MSMIN(raw_output, para->output_activation_max_));
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
|
@ -0,0 +1,31 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_DIV_INT8_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_DIV_INT8_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
int DivInt8(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, DivQuantArg *para);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_DIV_INT8_H_
|
|
@ -64,6 +64,114 @@ inline int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int3
|
|||
return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(value * (1 << left_shift), multiplier), -right_shift);
|
||||
}
|
||||
|
||||
inline int FractionsBits(int kIntegerBits) {
|
||||
int totalBits = 8 * sizeof(int32_t) - 1;
|
||||
return totalBits - kIntegerBits;
|
||||
}
|
||||
|
||||
inline int FixedPoint_One(int kIntegerBits, int kFractionsBits) {
|
||||
return (kIntegerBits == 0 ? INT32_MAX : ((1) << (uint32_t)(kIntegerBits == 0 ? 0 : kFractionsBits)));
|
||||
}
|
||||
|
||||
inline int RoundingHalfSum(int a, int b) {
|
||||
int64_t a64 = a;
|
||||
int64_t b64 = b;
|
||||
int64_t sum = a64 + b64;
|
||||
int64_t sign = sum > 0 ? 1 : -1;
|
||||
return (int32_t)((sum + sign) / 2);
|
||||
}
|
||||
|
||||
inline int32_t BitAnd(int32_t a, int32_t b) { return (uint32_t)a & (uint32_t)b; }
|
||||
|
||||
inline int32_t BitOr(int32_t a, int32_t b) { return (uint32_t)a | (uint32_t)b; }
|
||||
|
||||
inline int32_t BitXor(int32_t a, int32_t b) { return (uint32_t)a ^ (uint32_t)b; }
|
||||
|
||||
inline int32_t BitNot(int32_t a) { return ~(uint32_t)a; }
|
||||
|
||||
inline int SelectUsingMask(int mask, int bound, int val) {
|
||||
return BitXor(BitAnd(mask, bound), BitAnd(BitNot(mask), val));
|
||||
}
|
||||
|
||||
inline int32_t MaskNonZero(int32_t a) {
|
||||
int32_t zreo = 0;
|
||||
return a ? BitNot(zreo) : zreo;
|
||||
}
|
||||
|
||||
inline int SaturatingRoundingMultiplyByPOT(int32_t x, int Exponent) {
|
||||
int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? -1 : 0);
|
||||
if (ExponentSign == 0) {
|
||||
return x;
|
||||
} else if (ExponentSign == 1) {
|
||||
const int min = INT32_MIN;
|
||||
const int max = INT32_MAX;
|
||||
const int thresold = ((1 << (uint32_t)(31 - Exponent)) - 1);
|
||||
const int postive_mask = MaskNonZero(x > thresold);
|
||||
const int negative_mask = MaskNonZero(x < -thresold);
|
||||
int result = x << Exponent;
|
||||
result = SelectUsingMask(postive_mask, max, result);
|
||||
result = SelectUsingMask(negative_mask, min, result);
|
||||
return result;
|
||||
} else if (ExponentSign == -1) {
|
||||
return RoundingDivideByPOT(x, -Exponent);
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
inline int32_t Rescale(int x, int kIntegerBitsSrc, int kIntegerBitsDst) {
|
||||
int kExponent = kIntegerBitsSrc - kIntegerBitsDst;
|
||||
int result = SaturatingRoundingMultiplyByPOT(x, kExponent);
|
||||
return result;
|
||||
}
|
||||
|
||||
static inline int32_t one_over_one_plus_x_for_x_in_0_1(int32_t a) {
|
||||
int one = FixedPoint_One(0, FractionsBits(0));
|
||||
int half_denominator = RoundingHalfSum(a, one);
|
||||
const int constant_48_over_17 = 1515870810;
|
||||
const int constant_neg_32_over_17 = -1010580540;
|
||||
int x = constant_48_over_17 + SaturatingRoundingDoublingHighMul(half_denominator, constant_neg_32_over_17);
|
||||
for (int i = 0; i < 3; i++) {
|
||||
int half_denominator_times_x = SaturatingRoundingDoublingHighMul(half_denominator, x);
|
||||
int one_minus_half_denominator_times_x = FixedPoint_One(2, FractionsBits(2)) - half_denominator_times_x;
|
||||
x = x + Rescale(SaturatingRoundingDoublingHighMul(x, one_minus_half_denominator_times_x), 2 + 2, 2);
|
||||
}
|
||||
return Rescale(x, 2 - 1, 0);
|
||||
}
|
||||
|
||||
inline int CountLeadingZeroBits(uint32_t x) {
|
||||
#if defined(__GUNC__)
|
||||
return x ? __builtin_clz(x) : 8 * sizeof(uint32_t);
|
||||
#else
|
||||
if (x == 0) {
|
||||
return 8 * sizeof(uint32_t);
|
||||
}
|
||||
const int32_t leading_positive = (int32_t)(1) << (8 * sizeof(uint32_t) - 1);
|
||||
int leading_zeros = 0;
|
||||
while (x < leading_positive) {
|
||||
x <<= 1;
|
||||
leading_zeros++;
|
||||
}
|
||||
return leading_zeros;
|
||||
#endif
|
||||
}
|
||||
|
||||
inline int CountLeadingSignBits(int32_t x) {
|
||||
#if defined(__GUNC__) && !defined(__clang__)
|
||||
return x ? __builtin_clrsb(x) : 8 * sizeof(int32_t);
|
||||
#else
|
||||
return x >= 0 ? CountLeadingZeroBits((uint32_t)x) - 1 : x != INT32_MIN ? CountLeadingZeroBits(2 * (uint32_t)(-x)) : 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
static inline int32_t ComputerReciproal(int32_t x, int x_digits, int *recip_shift) {
|
||||
int leading_zreos_plus_one = CountLeadingZeroBits((uint32_t)x);
|
||||
*recip_shift = x_digits - leading_zreos_plus_one;
|
||||
const int32_t shifted_minus_one = (int32_t)(((uint32_t)x << leading_zreos_plus_one) - ((uint32_t)(1) << 31));
|
||||
const int32_t shifted_scaled = one_over_one_plus_x_for_x_in_0_1(shifted_minus_one);
|
||||
return shifted_scaled;
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -197,6 +197,16 @@ typedef struct ArithmeticQuantArg {
|
|||
QuantArg in1_args_;
|
||||
QuantArg out_args_;
|
||||
} ArithmeticQuantArg;
|
||||
|
||||
typedef struct DivQuantArg {
|
||||
QuantArg in0_args_;
|
||||
QuantArg in1_args_;
|
||||
QuantArg out_args_;
|
||||
int output_activation_min_;
|
||||
int output_activation_max_;
|
||||
int output_multiplier_;
|
||||
int output_shift_;
|
||||
} DivQuantArg;
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
|
|
@ -0,0 +1,74 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include "common/common_test.h"
|
||||
#include "mindspore/lite/src/runtime/kernel/arm/int8/div_int8.h"
|
||||
#include "mindspore/lite/src/kernel_registry.h"
|
||||
#include "mindspore/lite/include/context.h"
|
||||
|
||||
namespace mindspore {
|
||||
class TestDivInt8 : public mindspore::CommonTest {
|
||||
public:
|
||||
TestDivInt8() {}
|
||||
};
|
||||
|
||||
TEST_F(TestDivInt8, DivInt8) {
|
||||
lite::tensor::Tensor in_tensor0(kNumberTypeInt8, {1, 1, 2, 5});
|
||||
lite::tensor::Tensor in_tensor1(kNumberTypeInt8, {1, 1, 2, 5});
|
||||
lite::tensor::Tensor out_tensor(kNumberTypeInt8, {1, 1, 2, 5});
|
||||
|
||||
int8_t input_data0[] = {105, 35, -27, 0, -63, 99, 16, 45, 67, -49};
|
||||
int8_t input_data1[] = {126, -38, -115, 106, -98, 119, 103, 81, -114, 68};
|
||||
int8_t output_data[10] = {0};
|
||||
in_tensor0.SetData(input_data0);
|
||||
in_tensor1.SetData(input_data1);
|
||||
out_tensor.SetData(output_data);
|
||||
|
||||
const lite::tensor::QuantArg quant_in0 = {0.00784314f, 0}; // -1.0--1.0 -> 0--255
|
||||
const lite::tensor::QuantArg quant_in1 = {0.00784314f, 0};
|
||||
const lite::tensor::QuantArg quant_out = {0.00784314f, 0};
|
||||
in_tensor0.AddQuantParam(quant_in0);
|
||||
in_tensor1.AddQuantParam(quant_in1);
|
||||
out_tensor.AddQuantParam(quant_out);
|
||||
|
||||
std::vector<lite::tensor::Tensor *> inputs = {&in_tensor0, &in_tensor1};
|
||||
std::vector<lite::tensor::Tensor *> outputs = {&out_tensor};
|
||||
|
||||
OpParameter parameter = {};
|
||||
kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Div};
|
||||
|
||||
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
|
||||
ASSERT_NE(creator, nullptr);
|
||||
|
||||
auto ctx = std::make_shared<lite::Context>();
|
||||
auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(¶meter), ctx.get(), desc, nullptr);
|
||||
ASSERT_NE(kernel, nullptr);
|
||||
|
||||
auto ret = kernel->Run();
|
||||
EXPECT_EQ(0, ret);
|
||||
|
||||
int8_t expect0[10] = {106, -117, 30, 0, 82, 106, 20, 71, -75, -92};
|
||||
for (int i = 0; i < 10; ++i) {
|
||||
EXPECT_EQ(output_data[i], expect0[i]);
|
||||
}
|
||||
|
||||
in_tensor0.SetData(nullptr);
|
||||
in_tensor1.SetData(nullptr);
|
||||
out_tensor.SetData(nullptr);
|
||||
}
|
||||
} // namespace mindspore
|
Loading…
Reference in New Issue