!25695 [MSLITE] Fix bug and optimize reduce mean op in micro
Merge pull request !25695 from zhanyuan/reduce_micro
commit ae1ebf1430

@@ -88,8 +88,6 @@ int ReduceBaseCoder::Init() {
                       "memcpy_s failed!");
   }
   mode_ = reduce_param->mode_;
-  MS_CHECK_RET_CODE(memcpy_s(axes_, sizeof(axes_), reduce_param->axes_, sizeof(reduce_param->axes_)),
-                    "memcpy_s failed!");
   reduce_to_end_ = reduce_param->reduce_to_end_;
   MS_CHECK_RET_CODE(CheckInputsOutputs(), "CheckInputsOutputs failed!");
   return RET_OK;

@@ -23,20 +23,16 @@
 using mindspore::schema::PrimitiveType_ReduceFusion;
 
 namespace mindspore::lite::micro::nnacl {
-int ReduceInt8Coder::CalculateQuantArgs() {
-  LiteQuantParam input_quant = input_tensor_->quant_params().at(0);
-  LiteQuantParam output_quant = output_tensor_->quant_params().at(0);
-  quant_arg_.in_scale_ = input_quant.scale;
-  quant_arg_.in_zp_ = input_quant.zeroPoint;
-  quant_arg_.out_scale_ = output_quant.scale;
-  quant_arg_.out_zp_ = output_quant.zeroPoint;
-  const double input_output_multiplier = quant_arg_.in_scale_ / quant_arg_.out_scale_;
+int ReduceInt8Coder::CalReduceMeanQuantParam() {
   int shift;
-  QuantizeMultiplierSmallerThanOne(input_output_multiplier, &quant_arg_.in_out_multiplier_, &shift);
-  quant_arg_.in_out_left_shift_ = shift < 0 ? -shift : 0;
-  quant_arg_.in_out_right_shift_ = shift > 0 ? shift : 0;
-  MS_CHECK_TRUE(num_axes_ < MAX_SHAPE_SIZE, "the number of axes should be less the max num");
-  if (mode_ == static_cast<int>(schema::ReduceMode_ReduceMean)) {
+  if (axes_hw_pattern_) {
+    int reduce_num = input_tensor_->shape()[1] * input_tensor_->shape()[2];
+    bias_ = quant_arg_.out_zp_ - quant_arg_.in_zp_ * quant_arg_.in_scale_ / quant_arg_.out_scale_;
+    double reciprocal = quant_arg_.in_scale_ / (quant_arg_.out_scale_ * reduce_num);
+    QuantizeMultiplierSmallerThanOne(reciprocal, &reduce_mean_quant_param_.multiplier_, &shift);
+    reduce_mean_quant_param_.left_shift_ = shift < 0 ? -shift : 0;
+    reduce_mean_quant_param_.right_shift_ = shift > 0 ? shift : 0;
+  } else {
     for (int i = 0; i < num_axes_; ++i) {
       auto axis = axes_[i];
       std::vector<int> in_shape = input_tensor_->shape();
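
Both the per-axis path and the new HW fast path rely on QuantizeMultiplierSmallerThanOne to turn a real-valued scale ratio into integer parameters. As orientation, here is a minimal sketch of the gemmlowp-style decomposition such a helper conventionally performs; the body below is an assumption for illustration, not the nnacl implementation, but the sign convention matches the call sites above (shift > 0 becomes right_shift_, shift < 0 becomes left_shift_):

    #include <cmath>
    #include <cstdint>

    // Sketch: split real_multiplier into a Q31 fixed-point multiplier and a
    // power-of-two shift, so that x * real_multiplier can later be evaluated
    // as a 32-bit multiply-high followed by a rounding shift.
    void QuantizeMultiplierSmallerThanOneSketch(double real_multiplier, int32_t *quantized_multiplier, int *shift) {
      if (real_multiplier <= 0.0) {
        *quantized_multiplier = 0;
        *shift = 0;
        return;
      }
      int exponent = 0;
      const double fraction = std::frexp(real_multiplier, &exponent);  // real = fraction * 2^exponent
      int64_t q = static_cast<int64_t>(std::round(fraction * (1LL << 31)));
      if (q == (1LL << 31)) {  // rounding pushed the fraction up to exactly 1.0
        q /= 2;
        ++exponent;
      }
      *quantized_multiplier = static_cast<int32_t>(q);
      *shift = -exponent;  // multipliers below 1 give exponent <= 0, i.e. a right shift
    }
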
@@ -53,49 +49,84 @@ int ReduceInt8Coder::CalculateQuantArgs() {
       mean_multipliers_.push_back(qm);
     }
   }
+  return RET_OK;
+}
 
-  if (mode_ == static_cast<int>(schema::ReduceMode_ReduceProd)) {
-    for (int i = 0; i < num_axes_; ++i) {
-      int axis = axes_[i];
-      std::vector<int> in_shape = input_tensors_.at(kInputIndex)->shape();
-      if (static_cast<int>(in_shape.size()) - 1 < axis) {
-        MS_LOG(ERROR) << "input tensor shape is invalid";
-        return RET_ERROR;
-      }
-      int axis_size = in_shape.at(axis);
-      double prod_multiplier = std::pow(quant_arg_.in_scale_, axis_size - 1);
-      auto *qm = new (std::nothrow) QuantMulArg;
-      MS_CHECK_PTR(qm);
-      QuantizeMultiplierSmallerThanOne(prod_multiplier, &qm->multiplier_, &shift);
-      qm->left_shift_ = shift < 0 ? -shift : 0;
-      qm->right_shift_ = shift > 0 ? shift : 0;
-      prod_multipliers_.push_back(qm);
-    }
-  }
+int ReduceInt8Coder::CalReduceProdQuantParam() {
+  int shift;
+  for (int i = 0; i < num_axes_; ++i) {
+    int axis = axes_[i];
+    std::vector<int> in_shape = input_tensors_.at(kInputIndex)->shape();
+    if (static_cast<int>(in_shape.size()) - 1 < axis) {
+      MS_LOG(ERROR) << "input tensor shape is invalid";
+      return RET_ERROR;
+    }
+    int axis_size = in_shape.at(axis);
+    double prod_multiplier = std::pow(quant_arg_.in_scale_, axis_size - 1);
+    auto *qm = new (std::nothrow) QuantMulArg;
+    MS_CHECK_PTR(qm);
+    QuantizeMultiplierSmallerThanOne(prod_multiplier, &qm->multiplier_, &shift);
+    qm->left_shift_ = shift < 0 ? -shift : 0;
+    qm->right_shift_ = shift > 0 ? shift : 0;
+    prod_multipliers_.push_back(qm);
+  }
+  return RET_OK;
+}
 
-  if (mode_ == static_cast<int>(schema::ReduceMode_ReduceSumSquare)) {
-    for (int i = 0; i < num_axes_ - 1; ++i) {
-      auto *qm = new (std::nothrow) QuantMulArg;
-      MS_CHECK_PTR(qm);
-      double sum_square_multiplier = quant_arg_.in_scale_;
-      QuantizeMultiplierSmallerThanOne(sum_square_multiplier, &qm->multiplier_, &shift);
-      qm->left_shift_ = shift < 0 ? -shift : 0;
-      qm->right_shift_ = shift > 0 ? shift : 0;
-      sum_square_multipliers_.push_back(qm);
-    }
-    // for last num_axes
-    auto *qm = new (std::nothrow) QuantMulArg;
-    MS_CHECK_PTR(qm);
-    double sum_square_multiplier = quant_arg_.in_scale_ * (quant_arg_.in_scale_ / quant_arg_.out_scale_);
-    QuantizeMultiplierSmallerThanOne(sum_square_multiplier, &qm->multiplier_, &shift);
-    qm->left_shift_ = shift < 0 ? -shift : 0;
-    qm->right_shift_ = shift > 0 ? shift : 0;
-    sum_square_multipliers_.push_back(qm);
-  }
-  return RET_OK;
-}
+int ReduceInt8Coder::CalReduceSumSquareQuantParam() {
+  int shift;
+  for (int i = 0; i < num_axes_ - 1; ++i) {
+    auto *qm = new (std::nothrow) QuantMulArg;
+    MS_CHECK_PTR(qm);
+    double sum_square_multiplier = quant_arg_.in_scale_;
+    QuantizeMultiplierSmallerThanOne(sum_square_multiplier, &qm->multiplier_, &shift);
+    qm->left_shift_ = shift < 0 ? -shift : 0;
+    qm->right_shift_ = shift > 0 ? shift : 0;
+    sum_square_multipliers_.push_back(qm);
+  }
+
+  // for last num_axes
+  auto *qm = new (std::nothrow) QuantMulArg;
+  MS_CHECK_PTR(qm);
+  double sum_square_multiplier = quant_arg_.in_scale_ * (quant_arg_.in_scale_ / quant_arg_.out_scale_);
+  QuantizeMultiplierSmallerThanOne(sum_square_multiplier, &qm->multiplier_, &shift);
+  qm->left_shift_ = shift < 0 ? -shift : 0;
+  qm->right_shift_ = shift > 0 ? shift : 0;
+  sum_square_multipliers_.push_back(qm);
+  return RET_OK;
+}
+
+int ReduceInt8Coder::CalculateQuantArgs() {
+  LiteQuantParam input_quant = input_tensor_->quant_params().at(0);
+  LiteQuantParam output_quant = output_tensor_->quant_params().at(0);
+  quant_arg_.in_scale_ = input_quant.scale;
+  quant_arg_.in_zp_ = input_quant.zeroPoint;
+  quant_arg_.out_scale_ = output_quant.scale;
+  quant_arg_.out_zp_ = output_quant.zeroPoint;
+  const double input_output_multiplier = quant_arg_.in_scale_ / quant_arg_.out_scale_;
+  int shift;
+  QuantizeMultiplierSmallerThanOne(input_output_multiplier, &quant_arg_.in_out_multiplier_, &shift);
+  quant_arg_.in_out_left_shift_ = shift < 0 ? -shift : 0;
+  quant_arg_.in_out_right_shift_ = shift > 0 ? shift : 0;
+  MS_CHECK_TRUE(num_axes_ < MAX_SHAPE_SIZE, "the number of axes should be less the max num");
+  int ret = RET_OK;
+  switch (mode_) {
+    case static_cast<int>(schema::ReduceMode_ReduceMean):
+      ret = CalReduceMeanQuantParam();
+      break;
+    case static_cast<int>(schema::ReduceMode_ReduceProd):
+      ret = CalReduceProdQuantParam();
+      break;
+    case static_cast<int>(schema::ReduceMode_ReduceSumSquare):
+      ret = CalReduceSumSquareQuantParam();
+      break;
+    default:
+      ret = RET_ERROR;
+      MS_LOG(ERROR) << "Reduce mode not currently supported: " << mode_;
+      break;
+  }
+  return ret;
+}
 
 int ReduceInt8Coder::MallocTmpBuffer() {
   data_buffers_.clear();
   if (num_axes_ != static_cast<int>(buffer_sizes_.size())) {
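
The (multiplier_, left_shift_, right_shift_) triples computed above are consumed at runtime by the fixed-point helpers pulled in via fixed_point.c. A self-contained sketch of that consumption, following the usual gemmlowp conventions; the names mirror that convention, but these bodies are illustrative assumptions rather than the nnacl sources:

    #include <cstdint>
    #include <limits>

    // Rounding, saturating multiply returning the high 32 bits of a * b * 2.
    static int32_t SaturatingRoundingDoublingHighMulSketch(int32_t a, int32_t b) {
      if (a == std::numeric_limits<int32_t>::min() && b == a) {
        return std::numeric_limits<int32_t>::max();  // the single overflowing input pair
      }
      const int64_t prod = static_cast<int64_t>(a) * static_cast<int64_t>(b);
      const int32_t nudge = prod >= 0 ? (1 << 30) : (1 - (1 << 30));  // round to nearest
      return static_cast<int32_t>((prod + nudge) / (1LL << 31));
    }

    // Rounding division by 2^exponent (round half away from zero).
    static int32_t RoundingDivideByPOTSketch(int32_t x, int exponent) {
      const int32_t mask = (1 << exponent) - 1;
      const int32_t remainder = x & mask;
      const int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
      return (x >> exponent) + (remainder > threshold ? 1 : 0);
    }

    // Applies one triple: the integer equivalent of value * real_multiplier.
    int32_t MultiplyByQuantizedMultiplierSketch(int32_t value, int32_t multiplier, int left_shift, int right_shift) {
      return RoundingDivideByPOTSketch(
        SaturatingRoundingDoublingHighMulSketch(value * (1 << left_shift), multiplier), right_shift);
    }

Splitting the three reduce modes into CalReduceMeanQuantParam, CalReduceProdQuantParam, and CalReduceSumSquareQuantParam also turns three sequential if blocks into a single switch, so exactly one mode's parameters are computed and an unsupported mode now fails loudly instead of silently returning RET_OK.
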
@@ -136,6 +167,9 @@ void ReduceInt8Coder::GetQuantArgs(size_t index) {
 
 int ReduceInt8Coder::Prepare(CoderContext *const context) {
   MS_CHECK_RET_CODE(ReduceBaseCoder::Init(), "Init failed");
+  if (input_tensor_->shape().size() == DIMENSION_4D && num_axes_ == 2 && (axes_[0] + axes_[1]) == 3) {
+    axes_hw_pattern_ = true;
+  }
   std::vector<int> in_shape = input_tensor_->shape();
   if (!in_shape.empty()) {
     this->valid_shape_ = true;
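
The axes_hw_pattern_ flag keys the fast path off a 4-D input whose two reduce axes sum to 3, i.e. the NHWC case of averaging over H and W (axes 1 and 2). A hypothetical predicate spelling out that intent; note that the pair {0, 3} also sums to 3, so the shorthand above presumably relies on the axes arriving as [1, 2]:

    #include <vector>

    // Hypothetical helper, written out explicitly rather than via axes[0] + axes[1] == 3:
    // true when a 4-D NHWC tensor is reduced over exactly the spatial axes H and W.
    bool IsReduceOverHW(const std::vector<int> &shape, const int *axes, int num_axes) {
      return shape.size() == 4 && num_axes == 2 &&
             ((axes[0] == 1 && axes[1] == 2) || (axes[0] == 2 && axes[1] == 1));
    }
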
@@ -182,56 +216,84 @@ int ReduceInt8Coder::Prepare(CoderContext *const context) {
   if (!this->valid_shape_) {
     MS_CHECK_RET_CODE(CalculateQuantArgs(), "CalculateQuantArgs failed");
   }
-  MS_CHECK_RET_CODE(MallocTmpBuffer(), "MallocTmpBuffer failed");
-  begin_src_data_ = static_cast<int32_t *>(
-    allocator_->Malloc(kNumberTypeInt32, sizeof(int32_t) * input_tensor_->ElementsNum(), kWorkspace));
-  MS_CHECK_PTR(begin_src_data_);
+  if (axes_hw_pattern_) {
+    nchw_in_data_ = static_cast<int8_t *>(
+      allocator_->Malloc(kNumberTypeInt8, sizeof(int8_t) * input_tensor_->ElementsNum(), kWorkspace));
+    MS_CHECK_PTR(nchw_in_data_);
+  } else {
+    MS_CHECK_RET_CODE(MallocTmpBuffer(), "MallocTmpBuffer failed");
+    begin_src_data_ = static_cast<int32_t *>(
+      allocator_->Malloc(kNumberTypeInt32, sizeof(int32_t) * input_tensor_->ElementsNum(), kWorkspace));
+    MS_CHECK_PTR(begin_src_data_);
+  }
   return RET_OK;
 }
 
 int ReduceInt8Coder::DoCode(CoderContext *const context) {
   MS_LOG(DEBUG) << "*****Reduce code start*****";
   NNaclInt8Serializer code;
-  Collect(context,
-          {
-            "nnacl/int8/reduce_int8.h",
-          },
-          {
-            "reduce_int8.c",
-            "fixed_point.c",
-          });
-  std::string src_addr = allocator_->GetRuntimeAddr(input_tensor_);
-  std::string dst_addr;
-  std::string begin_src_data_src = allocator_->GetRuntimeAddr(begin_src_data_);
+  if (axes_hw_pattern_) {
+    Collect(context,
+            {
+              "nnacl/int8/pack_int8.h",
+              "nnacl/int8/reduce_int8.h",
+            },
+            {
+              "pack_int8.c",
+              "reduce_int8.c",
+              "fixed_point.c",
+            });
+    std::string input_origin = allocator_->GetRuntimeAddr(input_tensor_);
+    std::string input_nchw = allocator_->GetRuntimeAddr(nchw_in_data_);
+    std::string output = allocator_->GetRuntimeAddr(output_tensor_);
+    int n = input_tensor_->Batch();
+    int plane = input_tensor_->Height() * input_tensor_->Width();
+    int c = input_tensor_->Channel();
+    code.CodeFunction("PackNHWCToNCHWInt8", input_origin, input_nchw, n, plane, c);
+    std::string quant_param = "quant_param";
+    code.CodeStruct(quant_param, reduce_mean_quant_param_);
+    code.CodeFunction("ReduceMeanHW", n, plane, c, c, input_nchw, output, quant_param, bias_);
+  } else {
+    Collect(context,
+            {
+              "nnacl/int8/reduce_int8.h",
+            },
+            {
+              "reduce_int8.c",
+              "fixed_point.c",
+            });
+    std::string src_addr = allocator_->GetRuntimeAddr(input_tensor_);
+    std::string dst_addr;
+    std::string begin_src_data_src = allocator_->GetRuntimeAddr(begin_src_data_);
 
-  code << "int *begin_data = (int *)(" << begin_src_data_src << ");\n";
-  code << "int8_t *ori_data = (int8_t *)(" << src_addr << ");\n";
-  code << "for (int i = 0; i < " << input_tensor_->ElementsNum() << "; ++i) {\n"
-       << "   begin_data[i] = (int)ori_data[i];\n"
-       << " }\n";
-  for (int i = 0; i < num_axes_; ++i) {
-    GetQuantArgs(i);
-    std::string quant_arg_i = "quant_arg_" + std::to_string(i);
-    std::string ptr_quan_arg_i = "&" + quant_arg_i;
-    code.CodeStruct(quant_arg_i, quant_arg_);
-    if (i != num_axes_ - 1) {
-      is_last_axis = false;
-      dst_addr = allocator_->GetRuntimeAddr(data_buffers_.at(i));
-    } else {
-      is_last_axis = true;
-      dst_addr = allocator_->GetRuntimeAddr(output_tensor_);
-    }
-    outer_size_ = outer_sizes_.at(i);
-    inner_size_ = inner_sizes_.at(i);
-    axis_size_ = axis_sizes_.at(i);
-    if (!is_last_axis) {
-      code.CodeFunction(reducer_, outer_size_, inner_size_, axis_size_, begin_src_data_src, dst_addr, ptr_quan_arg_i,
-                        kDefaultTaskId, thread_num_);
-    } else {
-      code.CodeFunction(last_reducer_, outer_size_, inner_size_, axis_size_, begin_src_data_src, dst_addr,
-                        ptr_quan_arg_i, kDefaultTaskId, thread_num_);
-    }
-    begin_src_data_src = dst_addr;
-  }
+    code << "int *begin_data = (int *)(" << begin_src_data_src << ");\n";
+    code << "int8_t *ori_data = (int8_t *)(" << src_addr << ");\n";
+    code << "for (int i = 0; i < " << input_tensor_->ElementsNum() << "; ++i) {\n"
+         << "   begin_data[i] = (int)ori_data[i];\n"
+         << " }\n";
+    for (int i = 0; i < num_axes_; ++i) {
+      GetQuantArgs(i);
+      std::string quant_arg_i = "quant_arg_" + std::to_string(i);
+      std::string ptr_quan_arg_i = "&" + quant_arg_i;
+      code.CodeStruct(quant_arg_i, quant_arg_);
+      if (i != num_axes_ - 1) {
+        is_last_axis = false;
+        dst_addr = allocator_->GetRuntimeAddr(data_buffers_.at(i));
+      } else {
+        is_last_axis = true;
+        dst_addr = allocator_->GetRuntimeAddr(output_tensor_);
+      }
+      outer_size_ = outer_sizes_.at(i);
+      inner_size_ = inner_sizes_.at(i);
+      axis_size_ = axis_sizes_.at(i);
+      if (!is_last_axis) {
+        code.CodeFunction(reducer_, outer_size_, inner_size_, axis_size_, begin_src_data_src, dst_addr,
+                          ptr_quan_arg_i, kDefaultTaskId, thread_num_);
+      } else {
+        code.CodeFunction(last_reducer_, outer_size_, inner_size_, axis_size_, begin_src_data_src, dst_addr,
+                          ptr_quan_arg_i, kDefaultTaskId, thread_num_);
+      }
+      begin_src_data_src = dst_addr;
+    }
+  }
   context->AppendCode(code.str());
   return RET_OK;
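
On the fast path the generated code packs NHWC to NCHW and emits a single ReduceMeanHW call with the precomputed quant param and bias, replacing the old axis-by-axis reduction. Roughly, the per-channel computation it stands for is the sketch below; the signature, rounding, and clamping details are assumptions, and MultiplyByQuantizedMultiplierSketch is the helper sketched earlier. Note that bias_ = out_zp_ - in_zp_ * in_scale_ / out_scale_ from CalReduceMeanQuantParam is exactly the zero-point correction this form needs, since the reciprocal multiplier already folds 1/(H*W) into the rescale of the raw int8 sum:

    #include <cstdint>

    // Rough model of the generated ReduceMeanHW call: for each (batch, channel),
    // sum the H*W int8 inputs (NCHW layout), rescale the sum by
    // in_scale / (out_scale * H * W) through the fixed-point multiplier,
    // add the zero-point bias, and clamp to the int8 range.
    void ReduceMeanHWSketch(int n, int plane, int c, const int8_t *in_nchw, int8_t *out,
                            int32_t multiplier, int left_shift, int right_shift, int32_t bias) {
      for (int b = 0; b < n; ++b) {
        for (int ch = 0; ch < c; ++ch) {
          const int8_t *src = in_nchw + (b * c + ch) * plane;
          int32_t sum = 0;
          for (int i = 0; i < plane; ++i) {
            sum += src[i];
          }
          int32_t v = MultiplyByQuantizedMultiplierSketch(sum, multiplier, left_shift, right_shift) + bias;
          v = v < -128 ? -128 : (v > 127 ? 127 : v);
          out[b * c + ch] = static_cast<int8_t>(v);
        }
      }
    }
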
@@ -53,6 +53,9 @@ class ReduceInt8Coder final : public ReduceBaseCoder {
   int MallocTmpBuffer();
   int CalculateQuantArgs();
   void GetQuantArgs(size_t index);
+  int CalReduceMeanQuantParam();
+  int CalReduceProdQuantParam();
+  int CalReduceSumSquareQuantParam();
 
  private:
   ReduceQuantArg quant_arg_{0};
@@ -65,6 +68,10 @@ class ReduceInt8Coder final : public ReduceBaseCoder {
   std::vector<QuantMulArg *> mean_multipliers_;
   std::vector<QuantMulArg *> prod_multipliers_;
   std::vector<QuantMulArg *> sum_square_multipliers_;
+  bool axes_hw_pattern_{false};  // the second input tensor is [1 2](axes)
+  int32_t bias_{0};
+  QuantMulArg reduce_mean_quant_param_{};  // used for axes_hw_pattern_
+  int8_t *nchw_in_data_{nullptr};
 };
 }  // namespace mindspore::lite::micro::nnacl
 #endif  // MINDSPORE_LITE_MICRO_CODER_OPCODERS_INT8_REDUCE_INT8_CODER_H_