fix precision problem of sub int8 lite ops

This commit is contained in:
liuwenhao4 2020-10-24 14:24:39 +08:00
parent d2b1e783e7
commit 917907b70b
3 changed files with 57 additions and 4 deletions

View File

@ -70,5 +70,6 @@ PrimitiveC *InstanceNormCreator(const schema::Primitive *primitive) {
} }
Registry InstanceNormRegistry(schema::PrimitiveType_InstanceNorm, InstanceNormCreator); Registry InstanceNormRegistry(schema::PrimitiveType_InstanceNorm, InstanceNormCreator);
#endif #endif
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

View File

@ -92,14 +92,17 @@ int SubInt8CPUKernel::DoExecute(int task_id) {
MS_ASSERT(op_parameter_->thread_num_ != 0); MS_ASSERT(op_parameter_->thread_num_ != 0);
int stride = UP_DIV(element_num, op_parameter_->thread_num_); int stride = UP_DIV(element_num, op_parameter_->thread_num_);
int count = MSMIN(stride, element_num - stride * task_id); int count = MSMIN(stride, element_num - stride * task_id);
if (count <= 0) {
return RET_OK;
}
auto ret = RET_OK; auto ret = RET_OK;
if (broadcast_) { if (broadcast_) {
ret = SubInt8(tile0_data_ + task_id * count, tile1_data_ + task_id * count, output_data_ + task_id * count, count, ret = SubInt8(tile0_data_ + task_id * stride, tile1_data_ + task_id * stride, output_data_ + task_id * stride,
&param_); count, &param_);
} else { } else {
ret = SubInt8(input0_data_ + task_id * count, input1_data_ + task_id * count, output_data_ + task_id * count, count, ret = SubInt8(input0_data_ + task_id * stride, input1_data_ + task_id * stride, output_data_ + task_id * stride,
&param_); count, &param_);
} }
if (ret != RET_OK) { if (ret != RET_OK) {

View File

@ -57,6 +57,7 @@ TEST_F(TestSubInt8, SubInt8) {
ASSERT_NE(creator, nullptr); ASSERT_NE(creator, nullptr);
auto ctx = std::make_shared<lite::InnerContext>(); auto ctx = std::make_shared<lite::InnerContext>();
ctx->thread_num_ = 1;
ASSERT_EQ(lite::RET_OK, ctx->Init()); ASSERT_EQ(lite::RET_OK, ctx->Init());
auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(&parameter), ctx.get(), desc, nullptr); auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(&parameter), ctx.get(), desc, nullptr);
ASSERT_NE(kernel, nullptr); ASSERT_NE(kernel, nullptr);
@ -73,4 +74,52 @@ TEST_F(TestSubInt8, SubInt8) {
in_tensor1.set_data(nullptr); in_tensor1.set_data(nullptr);
out_tensor.set_data(nullptr); out_tensor.set_data(nullptr);
} }
TEST_F(TestSubInt8, SubInt8T2) {
lite::Tensor in_tensor0(kNumberTypeInt8, {1, 1, 2, 5});
lite::Tensor in_tensor1(kNumberTypeInt8, {1, 1, 1, 5});
lite::Tensor out_tensor(kNumberTypeInt8, {1, 1, 2, 5});
int8_t input_data0[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
int8_t input_data1[] = {1, 2, 3, 4, 5};
int8_t output_data[10] = {0};
in_tensor0.set_data(input_data0);
in_tensor1.set_data(input_data1);
out_tensor.set_data(output_data);
const lite::QuantArg quant_in0 = {0.00784314f, 0}; // -1.0--1.0 -> 0--255
const lite::QuantArg quant_in1 = {0.00784314f, 0};
const lite::QuantArg quant_out = {0.00784314f, 0};
in_tensor0.AddQuantParam(quant_in0);
in_tensor1.AddQuantParam(quant_in1);
out_tensor.AddQuantParam(quant_out);
std::vector<lite::Tensor *> inputs = {&in_tensor0, &in_tensor1};
std::vector<lite::Tensor *> outputs = {&out_tensor};
OpParameter parameter = {};
kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Sub};
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
ASSERT_NE(creator, nullptr);
auto ctx = std::make_shared<lite::InnerContext>();
ctx->thread_num_ = 2;
ASSERT_EQ(lite::RET_OK, ctx->Init());
auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(&parameter), ctx.get(), desc, nullptr);
ASSERT_NE(kernel, nullptr);
auto ret = kernel->Run();
EXPECT_EQ(0, ret);
int8_t expect0[10] = {0, 0, 0, 0, 0, 5, 5, 5, 5, 5};
for (int i = 0; i < 10; ++i) {
EXPECT_EQ(output_data[i], expect0[i]);
}
PrintData("output data", output_data, 10);
in_tensor0.set_data(nullptr);
in_tensor1.set_data(nullptr);
out_tensor.set_data(nullptr);
}
} // namespace mindspore } // namespace mindspore