!31648 fix magic number in matmul_sparse_fp32

Merge pull request !31648 from hangq/master
i-robot 2022-03-21 11:26:26 +00:00 committed by Gitee
commit 683597658a
1 changed file with 30 additions and 14 deletions


@@ -38,21 +38,34 @@ void MatmulSparseCPUKernel::InitParameter() {
   params_->b_const_ = false;
   auto a_shape = in_tensors_.at(0)->shape();
   int a_batch = 1;
-  for (size_t i = 0; i < a_shape.size() - 2; ++i) {
+  constexpr size_t batch_matmul_split = -2;
+  for (size_t i = 0; i < a_shape.size() + batch_matmul_split; ++i) {
     a_batch *= a_shape[i];
   }
   params_->batch = a_batch;
-  params_->row_ = params_->a_transpose_ ? a_shape[a_shape.size() - 1] : a_shape[a_shape.size() - 2];
-  params_->deep_ = params_->a_transpose_ ? a_shape[a_shape.size() - 2] : a_shape[a_shape.size() - 1];
+  constexpr size_t left_row_axis_transpose = -1;
+  constexpr size_t left_row_axis_not_transpose = -2;
+  constexpr size_t left_col_axis_transpose = -2;
+  constexpr size_t left_col_axis_not_transpose = -1;
+  params_->row_ = params_->a_transpose_ ? (a_shape[a_shape.size() + left_row_axis_transpose])
+                                        : (a_shape[a_shape.size() + left_row_axis_not_transpose]);
+  params_->deep_ = params_->a_transpose_ ? (a_shape[a_shape.size() + left_col_axis_transpose])
+                                         : (a_shape[a_shape.size() + left_col_axis_not_transpose]);
   auto b_shape = in_tensors_.at(1)->shape();
   int b_batch = 1;
-  for (size_t i = 0; i < b_shape.size() - 2; ++i) {
+  for (size_t i = 0; i < b_shape.size() + batch_matmul_split; ++i) {
     b_batch *= b_shape[i];
   }
   MS_ASSERT(a_batch == b_batch);
-  params_->col_ = params_->b_transpose_ ? b_shape[b_shape.size() - 2] : b_shape[b_shape.size() - 1];
-  params_->deep_ = params_->b_transpose_ ? b_shape[b_shape.size() - 1] : b_shape[b_shape.size() - 2];
+  constexpr size_t right_row_axis_transpose = -2;
+  constexpr size_t right_row_axis_not_transpose = -1;
+  constexpr size_t right_col_axis_transpose = -1;
+  constexpr size_t right_col_axis_not_transpose = -2;
+  params_->col_ = params_->b_transpose_ ? (b_shape[b_shape.size() + right_row_axis_transpose])
+                                        : (b_shape[b_shape.size() + right_row_axis_not_transpose]);
+  params_->deep_ = params_->b_transpose_ ? (b_shape[b_shape.size() + right_col_axis_transpose])
+                                         : (b_shape[b_shape.size() + right_col_axis_not_transpose]);
   params_->row_align_ = UP_ROUND(params_->row_, C8NUM);
   params_->col_align_ = UP_ROUND(params_->col_, C8NUM);
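Note on the new axis constants: they are declared as size_t but initialized with negative values, so a_shape.size() + batch_matmul_split relies on unsigned wraparound and addresses the second-to-last dimension, exactly like the old a_shape.size() - 2. A minimal standalone sketch of that arithmetic (hypothetical shape values, not the kernel code):

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> shape = {4, 3, 16, 32};          // hypothetical 4-D input shape
  constexpr size_t batch_matmul_split = -2;         // wraps to SIZE_MAX - 1
  size_t axis = shape.size() + batch_matmul_split;  // 4 + (2^64 - 2) wraps back to 2
  std::printf("axis = %zu, dim = %d\n", axis, shape[axis]);  // prints: axis = 2, dim = 16
  return 0;
}

The same wraparound applies to all the *_axis_* constants below; they are only safe while shape.size() >= 2.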
@@ -61,10 +74,11 @@ void MatmulSparseCPUKernel::InitParameter() {
   if (params_->a_transpose_) {
     return;
   }
+  constexpr int perm_1 = 2;
   auto area = params_->row_ * params_->deep_;
-  trans_param_.num_axes_ = 3;
+  trans_param_.num_axes_ = kNumIntThree;
   trans_param_.perm_[0] = 0;
-  trans_param_.perm_[1] = 2;
+  trans_param_.perm_[1] = perm_1;
   trans_param_.perm_[2] = 1;
   trans_param_.strides_[2] = 1;
   trans_param_.strides_[1] = params_->deep_;
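Note on the perm constant: perm_ = {0, 2, 1} applied to a [batch, row, deep] activation swaps the last two axes, i.e. it transposes each batch's row x deep matrix. A minimal sketch of what that permutation does (hypothetical helper, not the lite transpose kernel):

#include <cstdio>

// Swap the last two axes of a row-major [batch, row, deep] tensor (perm {0, 2, 1}).
void TransposeBatchRowDeep(const float *src, float *dst, int batch, int row, int deep) {
  for (int b = 0; b < batch; ++b) {
    for (int r = 0; r < row; ++r) {
      for (int d = 0; d < deep; ++d) {
        // destination follows the permuted axis order {batch, deep, row}
        dst[(b * deep + d) * row + r] = src[(b * row + r) * deep + d];
      }
    }
  }
}

int main() {
  const float src[1 * 2 * 3] = {1, 2, 3, 4, 5, 6};  // one batch, 2 rows, deep = 3
  float dst[1 * 3 * 2] = {0};
  TransposeBatchRowDeep(src, dst, 1, 2, 3);
  for (float v : dst) std::printf("%.0f ", v);  // prints: 1 4 2 5 3 6
  std::printf("\n");
  return 0;
}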
@@ -77,6 +91,7 @@ void MatmulSparseCPUKernel::InitParameter() {
 namespace {
 constexpr float kFpPrecision = 1e-6;
 constexpr size_t kBlockSize = 8;
+constexpr size_t bias_tensor_index = 2;
 }  // namespace
 
 int kernel::MatmulSparseCPUKernel::PrepareWeight() {
@@ -104,7 +119,7 @@ int kernel::MatmulSparseCPUKernel::PrepareWeight() {
       auto cur_data = weight_data[i * params_->col_ + j];
       if (cur_data > kFpPrecision) {
         sparsity_weight_->data[weight_data_index++] = cur_data;
-        sparsity_weight_->act_stride[act_stride_index++] = i * 8 * sizeof(float);
+        sparsity_weight_->act_stride[act_stride_index++] = i * kBlockSize * sizeof(float);
         (*(sparsity_weight_->non_zero_num + j))++;
       }
     }
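Note on the packing loop: every weight entry above kFpPrecision contributes one value, one activation byte stride (now i * kBlockSize * sizeof(float) instead of the literal 8), and one increment of that column's non-zero counter. A minimal sketch of the same CSC-style packing with hypothetical names (not the kernel's sparsity_weight_ struct):

#include <cstddef>
#include <vector>

// Hypothetical packed form: values column by column, a byte stride per value,
// and the number of kept (non-zero) entries per output column.
struct PackedSparseWeight {
  std::vector<float> data;
  std::vector<size_t> act_stride;
  std::vector<size_t> non_zero_num;
};

PackedSparseWeight PackSparse(const float *weight, size_t deep, size_t col,
                              size_t block = 8, float eps = 1e-6f) {
  PackedSparseWeight out;
  out.non_zero_num.assign(col, 0);
  for (size_t j = 0; j < col; ++j) {     // one output column at a time
    for (size_t i = 0; i < deep; ++i) {  // walk the deep dimension
      float v = weight[i * col + j];     // dense row-major [deep x col] layout
      if (v > eps) {
        out.data.push_back(v);
        out.act_stride.push_back(i * block * sizeof(float));  // same stride rule as the diff
        ++out.non_zero_num[j];
      }
    }
  }
  return out;
}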
@@ -113,8 +128,9 @@ int kernel::MatmulSparseCPUKernel::PrepareWeight() {
 }
 
 int MatmulSparseCPUKernel::PrepareBias() {
-  if (in_tensors_.size() == 3) {
-    auto bias_tensor = in_tensors_[2];
+  constexpr size_t has_bias_tensor_num = 3;
+  if (in_tensors_.size() == has_bias_tensor_num) {
+    auto bias_tensor = in_tensors_[bias_tensor_index];
     if (bias_tensor->ElementsNum() != params_->col_) {
       MS_LOG(ERROR) << "Not support broadcast bias data now";
       return lite::RET_NOT_SUPPORT;
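Note on the bias check: the inputs are {activation, weight, optional bias}, so a size of has_bias_tensor_num (3) means a bias is present, and it must have exactly col_ elements because broadcasting is not supported. A minimal sketch of that check with a hypothetical Tensor stand-in (not the lite::Tensor API):

#include <cstddef>
#include <vector>

struct Tensor { size_t elements; };  // hypothetical stand-in for a tensor's element count

bool BiasUsable(const std::vector<Tensor> &inputs, size_t col) {
  constexpr size_t has_bias_tensor_num = 3;
  constexpr size_t bias_tensor_index = 2;
  if (inputs.size() != has_bias_tensor_num) return false;  // no bias tensor provided
  return inputs[bias_tensor_index].elements == col;        // broadcast bias not supported
}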
@@ -231,7 +247,7 @@ int MatmulSparseCPUKernel::RunInstrinsics() {
     printf("\r\n");
   }
   {
-    auto bias = reinterpret_cast<float *>(in_tensors_.at(2)->data());
+    auto bias = reinterpret_cast<float *>(in_tensors_.at(bias_tensor_index)->data());
     printf("=========================================bias:\r\n");
     for (size_t i = 0; i < params_->col_; i++) {
       printf(" %2.2f", bias[i]);
@@ -260,7 +276,7 @@ int MatmulSparseCPUKernel::RunInstrinsics() {
     printf("\r\n");
   }
 #endif
-  auto bias = reinterpret_cast<float *>(in_tensors_.at(2)->data());
+  auto bias = reinterpret_cast<float *>(in_tensors_.at(bias_tensor_index)->data());
   auto output = reinterpret_cast<float *>(out_tensors_.front()->data());
   for (int i = 0; i < params_->row_align_ / kBlockSize; i++) {
     MatMulSparse8x8(a_pack_ + i * kBlockSize * params_->deep_, sparsity_weight_->data, sparsity_weight_->non_zero_num,
@@ -295,7 +311,7 @@ int MatmulSparseCPUKernel::Run() {
     return ret;
   }
   CHECK_NULL_RETURN(sparsity_weight_);
-  auto bias = reinterpret_cast<float *>(in_tensors_.at(2)->data());
+  auto bias = reinterpret_cast<float *>(in_tensors_.at(bias_tensor_index)->data());
   auto output = reinterpret_cast<float *>(out_tensors_.front()->data());
   for (int i = 0; i < params_->row_align_ / kBlockSize; i++) {
     SPMM8x8Fp32(a_pack_ + i * params_->deep_ * kBlockSize, sparsity_weight_->data, sparsity_weight_->non_zero_num,
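Note on the block loop: row_align_ is rounded up to a multiple of C8NUM in InitParameter, so row_align_ / kBlockSize is the number of 8-row activation blocks handed to the sparse micro-kernel. A minimal sketch of that alignment arithmetic (UpRound is a hypothetical stand-in for the UP_ROUND macro):

#include <cstdio>

constexpr int UpRound(int x, int align) { return ((x + align - 1) / align) * align; }

int main() {
  constexpr int kBlockSize = 8;
  int row = 13;                              // hypothetical output row count
  int row_align = UpRound(row, kBlockSize);  // 16: rows padded up to a multiple of 8
  int row_blocks = row_align / kBlockSize;   // 2 invocations of the 8x8 micro-kernel
  std::printf("row_align = %d, blocks = %d\n", row_align, row_blocks);
  return 0;
}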