Add matmul broadcast

This commit is contained in:
duzhixing 2022-10-08 16:47:43 +08:00
parent ca49c7577c
commit d055fa3337
5 changed files with 106 additions and 28 deletions

View File

@ -25,7 +25,7 @@ void DynamicMatmul4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_
* support activation per-layer symmetric && weight per-layer/per-channel symmetric * support activation per-layer symmetric && weight per-layer/per-channel symmetric
* */ * */
for (int r = 0; r < row; r++) { for (int r = 0; r < row; r++) {
int64_t s2 = a_sums[r] * b_zp_sum; int64_t s2 = a_sums[r];
for (int c = 0; c < col; c++) { for (int c = 0; c < col; c++) {
int r4div = r / C4NUM, r4mod = r % C4NUM; int r4div = r / C4NUM, r4mod = r % C4NUM;
int c16div = c / C16NUM, c16mod = c % C16NUM; int c16div = c / C16NUM, c16mod = c % C16NUM;

View File

@ -17,6 +17,9 @@
#include "src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.h" #include "src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.h"
#include "nnacl/int8/dynamic_matmul_int8.h" #include "nnacl/int8/dynamic_matmul_int8.h"
using mindspore::lite::kCHWDimNumber;
using mindspore::lite::kHWDimNumber;
using mindspore::lite::kNCHWDimNumber;
using mindspore::lite::RET_ERROR; using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_MEMORY_FAILED; using mindspore::lite::RET_MEMORY_FAILED;
using mindspore::lite::RET_OK; using mindspore::lite::RET_OK;
@ -105,7 +108,7 @@ void MatmulDynamicBaseInt8CPUKernel::ResizeMatrixBParameter() {
for (size_t i = 0; i < w_shape.size() - kSize2; ++i) { for (size_t i = 0; i < w_shape.size() - kSize2; ++i) {
batch *= w_shape[i]; batch *= w_shape[i];
} }
param_->batch = batch; b_batch_ = batch;
param_->col_ = param_->b_transpose_ ? w_shape[w_shape.size() - kSize2] : w_shape[w_shape.size() - kSize1]; param_->col_ = param_->b_transpose_ ? w_shape[w_shape.size() - kSize2] : w_shape[w_shape.size() - kSize1];
param_->deep_ = param_->b_transpose_ ? w_shape[w_shape.size() - kSize1] : w_shape[w_shape.size() - kSize2]; param_->deep_ = param_->b_transpose_ ? w_shape[w_shape.size() - kSize1] : w_shape[w_shape.size() - kSize2];
@ -155,7 +158,7 @@ int MatmulDynamicBaseInt8CPUKernel::InitInputQuantParam() {
int MatmulDynamicBaseInt8CPUKernel::TransferB() { int MatmulDynamicBaseInt8CPUKernel::TransferB() {
auto weight_data = reinterpret_cast<int8_t *>(in_tensors_.at(kWeightIndex)->data()); auto weight_data = reinterpret_cast<int8_t *>(in_tensors_.at(kWeightIndex)->data());
CHECK_NULL_RETURN(weight_data); CHECK_NULL_RETURN(weight_data);
for (int i = 0; i < param_->batch; i++) { for (int i = 0; i < b_batch_; i++) {
auto current_weight = weight_data + i * param_->deep_ * param_->col_; auto current_weight = weight_data + i * param_->deep_ * param_->col_;
auto current_b_pack = pack_b_ptr_ + i * param_->col_align_ * param_->deep_align_; auto current_b_pack = pack_b_ptr_ + i * param_->col_align_ * param_->deep_align_;
auto current_sums = weight_sums_ + i * param_->col_align_; auto current_sums = weight_sums_ + i * param_->col_align_;
@ -201,7 +204,7 @@ int MatmulDynamicBaseInt8CPUKernel::InitMatrixBBuffer() {
pack_b_ptr_ = nullptr; pack_b_ptr_ = nullptr;
} }
pack_b_ptr_ = pack_b_ptr_ =
reinterpret_cast<int8_t *>(malloc(param_->batch * param_->col_align_ * param_->deep_align_ * sizeof(int8_t))); reinterpret_cast<int8_t *>(malloc(b_batch_ * param_->col_align_ * param_->deep_align_ * sizeof(int8_t)));
if (pack_b_ptr_ == nullptr) { if (pack_b_ptr_ == nullptr) {
FreeTmpBuffer(); FreeTmpBuffer();
return RET_ERROR; return RET_ERROR;
@ -210,13 +213,13 @@ int MatmulDynamicBaseInt8CPUKernel::InitMatrixBBuffer() {
free(weight_sums_); free(weight_sums_);
weight_sums_ = nullptr; weight_sums_ = nullptr;
} }
weight_sums_ = reinterpret_cast<int *>(malloc(param_->batch * param_->col_align_ * sizeof(int))); weight_sums_ = reinterpret_cast<int *>(malloc(b_batch_ * param_->col_align_ * sizeof(int)));
if (weight_sums_ == nullptr) { if (weight_sums_ == nullptr) {
FreeTmpBuffer(); FreeTmpBuffer();
return RET_ERROR; return RET_ERROR;
} }
memset(pack_b_ptr_, 0, param_->batch * param_->col_align_ * param_->deep_align_ * sizeof(int8_t)); memset(pack_b_ptr_, 0, b_batch_ * param_->col_align_ * param_->deep_align_ * sizeof(int8_t));
memset(weight_sums_, 0, param_->batch * param_->col_align_ * sizeof(int)); memset(weight_sums_, 0, b_batch_ * param_->col_align_ * sizeof(int));
return RET_OK; return RET_OK;
} }
@ -289,27 +292,22 @@ int MatmulDynamicBaseInt8CPUKernel::Prepare() {
int MatmulDynamicBaseInt8CPUKernel::ReSize() { int MatmulDynamicBaseInt8CPUKernel::ReSize() {
auto x_shape = in_tensors_.at(0)->shape(); auto x_shape = in_tensors_.at(0)->shape();
auto y_shape = in_tensors_.at(1)->shape();
auto o_shape = out_tensors_.at(0)->shape(); auto o_shape = out_tensors_.at(0)->shape();
MS_ASSERT(o_shape.size() >= kSize2); MS_ASSERT(o_shape.size() >= kSize2);
unsigned int i = 0;
param_->row_ = param_->a_transpose_ ? x_shape[x_shape.size() - kSize1] : x_shape[x_shape.size() - kSize2];
param_->batch = 1;
for (; i < x_shape.size() - kSize2; i++) {
if (x_shape[i] != y_shape[i]) {
break;
}
param_->batch *= x_shape[i];
}
for (; i < x_shape.size() - kSize2; i++) {
param_->row_ *= x_shape[i];
}
param_->row_ = o_shape[o_shape.size() - kSize2];
param_->row_align_ = UP_ROUND(param_->row_, row_tile_); param_->row_align_ = UP_ROUND(param_->row_, row_tile_);
param_->deep_ = param_->a_transpose_ ? x_shape[x_shape.size() - kSize2] : x_shape[x_shape.size() - kSize1]; param_->deep_ = param_->a_transpose_ ? x_shape[x_shape.size() - kSize2] : x_shape[x_shape.size() - kSize1];
param_->deep_align_ = UP_ROUND(param_->deep_, deep_tile_); param_->deep_align_ = UP_ROUND(param_->deep_, deep_tile_);
auto ret = InitMatrixABuffer(); auto ret = InitBroadcastParams(in_tensors_[kInputIndex]->shape(), in_tensors_[kWeightIndex]->shape(), param_,
&a_offset_, &b_offset_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "InitBroadcastParams failed.";
return RET_ERROR;
}
ret = InitMatrixABuffer();
if (ret != RET_OK) { if (ret != RET_OK) {
FreeQuantParam(); FreeQuantParam();
return ret; return ret;
@ -325,4 +323,74 @@ int MatmulDynamicBaseInt8CPUKernel::ReSize() {
} }
return RET_OK; return RET_OK;
} }
int MatmulDynamicBaseInt8CPUKernel::InitBroadcastParams(const std::vector<int> &a_shape_const,
const std::vector<int> &b_shape_const, MatMulParameter *params,
std::vector<int> *a_offsets, std::vector<int> *b_offsets) {
std::vector<int> a_shape = a_shape_const;
if (a_shape.size() < kNCHWDimNumber) {
size_t add_nums = kNCHWDimNumber - a_shape.size();
for (size_t i = 0; i < add_nums; ++i) {
(void)a_shape.insert(a_shape.begin(), 1);
}
}
std::vector<int> b_shape = b_shape_const;
if (b_shape.size() < kNCHWDimNumber) {
size_t add_nums = kNCHWDimNumber - b_shape.size();
for (size_t i = 0; i < add_nums; ++i) {
(void)b_shape.insert(b_shape.begin(), 1);
}
}
int batch_sizes[MAX_SHAPE_SIZE] = {0};
int a_batch_sizes[MAX_SHAPE_SIZE] = {0};
int b_batch_sizes[MAX_SHAPE_SIZE] = {0};
for (int i = a_shape.size() - kCHWDimNumber; i >= 0; --i) {
if (static_cast<int>(a_shape.size() - kCHWDimNumber) == i) {
batch_sizes[i] = std::max(a_shape[i], b_shape[i]);
a_batch_sizes[i] = a_shape[i];
b_batch_sizes[i] = b_shape[i];
} else {
batch_sizes[i] = batch_sizes[i + 1] * std::max(a_shape[i], b_shape[i]);
a_batch_sizes[i] = a_batch_sizes[i + 1] * a_shape[i];
b_batch_sizes[i] = b_batch_sizes[i + 1] * b_shape[i];
}
}
int out_batch = 1;
for (size_t i = 0; i < a_shape.size() - kHWDimNumber; ++i) {
int max_v = MSMAX(a_shape[i], b_shape[i]);
int min_v = MSMIN(a_shape[i], b_shape[i]) > 0 ? MSMIN(a_shape[i], b_shape[i]) : 1;
out_batch *= max_v;
if (max_v != min_v && max_v % min_v != 0) {
MS_LOG(ERROR) << "matmul don't support broadcast for dimension " << a_shape << " and " << b_shape;
return RET_ERROR;
}
}
params->batch = out_batch;
a_offsets->resize(params->batch, 0);
b_offsets->resize(params->batch, 0);
for (int i = 0; i < params->batch; ++i) {
int64_t delta = i;
int a_offset = 0;
int b_offset = 0;
for (size_t j = 0; j < a_shape.size() - kHWDimNumber; ++j) {
if (j > 0) {
delta = delta % batch_sizes[j];
}
if (j < (a_shape.size() - kCHWDimNumber)) {
a_offset += (delta / batch_sizes[j + 1] * a_shape[j] / std::max(a_shape[j], b_shape[j])) * a_batch_sizes[j + 1];
b_offset += (delta / batch_sizes[j + 1] * b_shape[j] / std::max(a_shape[j], b_shape[j])) * b_batch_sizes[j + 1];
} else {
a_offset += (delta * a_shape[j] / std::max(a_shape[j], b_shape[j]));
b_offset += (delta * b_shape[j] / std::max(a_shape[j], b_shape[j]));
}
}
(*a_offsets)[i] = a_offset;
(*b_offsets)[i] = b_offset;
}
return RET_OK;
}
} // namespace mindspore::kernel } // namespace mindspore::kernel

View File

@ -18,12 +18,14 @@
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_INT8_MATMUL_DYNAMIC_BASE_INT8_H_ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_INT8_MATMUL_DYNAMIC_BASE_INT8_H_
#include <vector> #include <vector>
#include <algorithm>
#include "include/errorcode.h" #include "include/errorcode.h"
#include "src/litert/lite_kernel.h" #include "src/litert/lite_kernel.h"
#include "nnacl/matmul_parameter.h" #include "nnacl/matmul_parameter.h"
#include "nnacl/common_func.h" #include "nnacl/common_func.h"
#include "nnacl/int8/quantize.h" #include "nnacl/int8/quantize.h"
#include "nnacl/int8/common_func_int8.h" #include "nnacl/int8/common_func_int8.h"
#include "src/common/common.h"
namespace mindspore::kernel { namespace mindspore::kernel {
class MatmulDynamicBaseInt8CPUKernel : public LiteKernel { class MatmulDynamicBaseInt8CPUKernel : public LiteKernel {
@ -36,6 +38,8 @@ class MatmulDynamicBaseInt8CPUKernel : public LiteKernel {
~MatmulDynamicBaseInt8CPUKernel() override; ~MatmulDynamicBaseInt8CPUKernel() override;
int Prepare() override; int Prepare() override;
int ReSize() override; int ReSize() override;
static int InitBroadcastParams(const std::vector<int> &a_shape_const, const std::vector<int> &b_shape_const,
MatMulParameter *params, std::vector<int> *a_offsets, std::vector<int> *b_offsets);
private: private:
void ResizeMatrixBParameter(); void ResizeMatrixBParameter();
@ -45,6 +49,10 @@ class MatmulDynamicBaseInt8CPUKernel : public LiteKernel {
int MallocQuantParam(); int MallocQuantParam();
protected: protected:
int a_batch_ = 1;
int b_batch_ = 1;
std::vector<int> a_offset_;
std::vector<int> b_offset_;
typedef void (*PackFunc)(const int8_t *src, int8_t *dst, int row, int col); typedef void (*PackFunc)(const int8_t *src, int8_t *dst, int row, int col);
virtual void InitParameter() = 0; virtual void InitParameter() = 0;
int TransferA(); int TransferA();
@ -63,6 +71,7 @@ class MatmulDynamicBaseInt8CPUKernel : public LiteKernel {
bool filter_per_channel_ = true; bool filter_per_channel_ = true;
int8_t *batch_input_ptr_ = nullptr; int8_t *batch_input_ptr_ = nullptr;
int8_t *batch_weight_ptr_ = nullptr; int8_t *batch_weight_ptr_ = nullptr;
int8_t *batch_a_ptr_ = nullptr;
int8_t *batch_b_ptr_ = nullptr; int8_t *batch_b_ptr_ = nullptr;
float *batch_c_ptr_ = nullptr; float *batch_c_ptr_ = nullptr;
int *input_sums_ = nullptr; int *input_sums_ = nullptr;

View File

@ -54,7 +54,7 @@ int MatmulDynamicInt8CPUKernel::RunImpl(int task_id) {
if (filter_per_channel_) { if (filter_per_channel_) {
filter_scale += cur_stride; filter_scale += cur_stride;
} }
DynamicMatmul4x16x4AIWI(pack_a_ptr_, batch_b_ptr_ + cur_stride * param_->deep_align_, bias_ptr, DynamicMatmul4x16x4AIWI(batch_a_ptr_, batch_b_ptr_ + cur_stride * param_->deep_align_, bias_ptr,
batch_c_ptr_ + cur_stride, param_->row_, cur_oc, param_->deep_, param_->deep_align_, batch_c_ptr_ + cur_stride, param_->row_, cur_oc, param_->deep_, param_->deep_align_,
param_->col_, quant_param_->input_zp_, quant_param_->input_scale_, filter_scale, filter_zp, param_->col_, quant_param_->input_zp_, quant_param_->input_scale_, filter_scale, filter_zp,
filter_per_channel_); filter_per_channel_);
@ -105,7 +105,7 @@ int MatmulDynamicInt8CPUKernel::Run() {
CHECK_NULL_RETURN(c_ptr); CHECK_NULL_RETURN(c_ptr);
for (int i = 0; i < param_->batch; i++) { for (int i = 0; i < param_->batch; i++) {
memset(pack_a_ptr_, quant_param_->input_zp_, param_->row_align_ * param_->deep_align_ * sizeof(int8_t)); memset(pack_a_ptr_, quant_param_->input_zp_, param_->row_align_ * param_->deep_align_ * sizeof(int8_t));
auto current_src_a = a_ptr + i * param_->row_ * param_->deep_; auto current_src_a = a_ptr + a_offset_[i] * param_->row_ * param_->deep_;
if (param_->a_transpose_) { if (param_->a_transpose_) {
MS_CHECK_TRUE_RET(a_pack_func_ != nullptr, RET_ERROR); MS_CHECK_TRUE_RET(a_pack_func_ != nullptr, RET_ERROR);
a_pack_func_(current_src_a, pack_a_ptr_, param_->deep_, param_->row_); a_pack_func_(current_src_a, pack_a_ptr_, param_->deep_, param_->row_);
@ -114,7 +114,8 @@ int MatmulDynamicInt8CPUKernel::Run() {
a_pack_func_(current_src_a, pack_a_ptr_, param_->row_, param_->deep_); a_pack_func_(current_src_a, pack_a_ptr_, param_->row_, param_->deep_);
} }
batch_b_ptr_ = pack_b_ptr_ + i * param_->col_align_ * param_->deep_align_; batch_a_ptr_ = pack_a_ptr_;
batch_b_ptr_ = pack_b_ptr_ + b_offset_[i] * param_->col_align_ * param_->deep_align_;
batch_c_ptr_ = c_ptr + i * param_->row_ * param_->col_; batch_c_ptr_ = c_ptr + i * param_->row_ * param_->col_;
ret = ParallelLaunch(this->ms_context_, MatmulDynamicInt8Run, this, thread_count_); ret = ParallelLaunch(this->ms_context_, MatmulDynamicInt8Run, this, thread_count_);

View File

@ -167,16 +167,16 @@ int MatMulDynamicSdotInt8Kernel::MatMulDynamicRunArm64Sdot() {
CHECK_NULL_RETURN(c_ptr); CHECK_NULL_RETURN(c_ptr);
for (int i = 0; i < param_->batch; i++) { for (int i = 0; i < param_->batch; i++) {
batch_input_ptr_ = a_ptr + i * param_->row_ * param_->deep_; batch_input_ptr_ = a_ptr + a_offset_[i] * param_->row_ * param_->deep_;
auto ret = ParallelLaunch(this->ms_context_, Arm64SdotPreRun, this, op_parameter_->thread_num_); auto ret = ParallelLaunch(this->ms_context_, Arm64SdotPreRun, this, op_parameter_->thread_num_);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Arm64SdotPreRun error: [" << ret << "]"; MS_LOG(ERROR) << "Arm64SdotPreRun error: [" << ret << "]";
return ret; return ret;
} }
batch_weight_ptr_ = b_ptr + i * param_->col_ * param_->deep_; batch_weight_ptr_ = b_ptr + b_offset_[i] * param_->col_ * param_->deep_;
batch_sums_ = weight_sums_ + i * param_->col_align_; batch_sums_ = weight_sums_ + b_offset_[i] * param_->col_align_;
batch_b_ptr_ = pack_b_ptr_ + i * param_->col_align_ * param_->deep_align_; batch_b_ptr_ = pack_b_ptr_ + b_offset_[i] * param_->col_align_ * param_->deep_align_;
batch_c_ptr_ = c_ptr + i * param_->row_ * param_->col_; batch_c_ptr_ = c_ptr + i * param_->row_ * param_->col_;
ret = ParallelLaunch(this->ms_context_, Arm64SdotRun, this, thread_count_); ret = ParallelLaunch(this->ms_context_, Arm64SdotRun, this, thread_count_);