!4635 Fix aborted bug of softmax op

Merge pull request !4635 from wangminggui/master
This commit is contained in:
mindspore-ci-bot 2020-08-18 10:09:58 +08:00 committed by Gitee
commit 21502810dd
5 changed files with 17 additions and 10 deletions

View File

@ -15,6 +15,7 @@
*/
#include "src/runtime/kernel/arm/int8/softmax_int8.h"
#include <limits>
#include "src/runtime/kernel/arm/nnacl/int8/softmax_int8.h"
#include "schema/model_generated.h"
#include "src/runtime/runtime_api.h"
@ -44,6 +45,8 @@ int SoftmaxInt8CPUKernel::Init() {
auto out_quant_args = out_tensor->GetQuantParams();
quant_params_.out_quant_arg_.scale_ = out_quant_args.front().scale;
quant_params_.out_quant_arg_.zp_ = out_quant_args.front().zeroPoint;
quant_params_.output_activation_min_ = std::numeric_limits<int8_t>::min();
quant_params_.output_activation_max_ = std::numeric_limits<int8_t>::max();
if (!InferShapeDone()) {
return RET_OK;
@ -95,12 +98,10 @@ int SoftmaxInt8CPUKernel::DoSoftmax(int task_id) {
int stride = UP_DIV(outter_size, thread_count_);
int count = MSMIN(stride, outter_size - stride * task_id);
int stride_size = stride * task_id * inner_size;
input_ptr += stride * task_id * inner_size;
output_ptr += stride * task_id * inner_size;
exp_data_ += stride * task_id * inner_size;
auto error_code = Int8Softmax(input_ptr, output_ptr, count, exp_data_, sum_data_, quant_params_, softmax_param_);
auto error_code = SoftmaxInt8(input_ptr + stride_size, output_ptr + stride_size, count, exp_data_ + stride_size,
sum_data_, quant_params_, softmax_param_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "DoSoftmax error task_id[" << task_id << "] error_code[" << error_code << "]";
return RET_ERROR;

View File

@ -37,8 +37,8 @@ class SoftmaxInt8CPUKernel : public SoftmaxBaseCPUKernel {
private:
void FreeTmpBuffer();
float *sum_data_;
float *exp_data_;
float *sum_data_ = nullptr;
float *exp_data_ = nullptr;
SoftmaxQuantArg quant_params_;
};
} // namespace mindspore::kernel

View File

@ -17,7 +17,7 @@
#include "nnacl/int8/softmax_int8.h"
#include <math.h>
int Int8Softmax(const int8_t *input_ptr, int8_t *output_ptr, int count, float *exp_data, float *sum_data,
int SoftmaxInt8(const int8_t *input_ptr, int8_t *output_ptr, int count, float *exp_data, float *sum_data,
SoftmaxQuantArg quant_param, SoftmaxParameter *parameter) {
int32_t axis = parameter->axis_;
int n_dim = parameter->n_dim_;
@ -48,7 +48,8 @@ int Int8Softmax(const int8_t *input_ptr, int8_t *output_ptr, int count, float *e
int inner_offset = axis_offset + i;
float real_output = exp_data[inner_offset] / sum_data[i];
int32_t output_scaled = round(real_output / output_scale) + output_zp;
output_ptr[inner_offset] = MSMAX(CHAR_MIN, MSMIN(CHAR_MAX, output_scaled));
output_ptr[inner_offset] =
MSMAX(quant_param.output_activation_min_, MSMIN(quant_param.output_activation_max_, output_scaled));
}
}
}

View File

@ -24,7 +24,7 @@
#ifdef __cplusplus
extern "C" {
#endif
int Int8Softmax(const int8_t *input_ptr, int8_t *output_ptr, int count, float *exp_data, float *sum_data,
int SoftmaxInt8(const int8_t *input_ptr, int8_t *output_ptr, int count, float *exp_data, float *sum_data,
SoftmaxQuantArg quant_param, SoftmaxParameter *parameter);
#ifdef __cplusplus
}

View File

@ -169,6 +169,11 @@ typedef struct SplitQuantArg {
typedef struct SoftmaxQuantArg {
QuantArg in_quant_args_;
QuantArg out_quant_arg_;
int output_activation_min_;
int output_activation_max_;
int output_multiplier_;
int shift_left_;
int shift_right_;
} SoftmaxQuantArg;
typedef struct ReshapeQuantArg {