!43356 [Feature] Add dynamic matmul broadcast
Merge pull request !43356 from douzhixing/broadcast
This commit is contained in:
commit
1dea06e0d3
|
@ -25,7 +25,7 @@ void DynamicMatmul4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_
|
|||
* support activation per-layer symmetric && weight per-layer/per-channel symmetric
|
||||
* */
|
||||
for (int r = 0; r < row; r++) {
|
||||
int64_t s2 = a_sums[r] * b_zp_sum;
|
||||
int64_t s2 = a_sums[r];
|
||||
for (int c = 0; c < col; c++) {
|
||||
int r4div = r / C4NUM, r4mod = r % C4NUM;
|
||||
int c16div = c / C16NUM, c16mod = c % C16NUM;
|
||||
|
|
|
@ -17,6 +17,9 @@
|
|||
#include "src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.h"
|
||||
#include "nnacl/int8/dynamic_matmul_int8.h"
|
||||
|
||||
using mindspore::lite::kCHWDimNumber;
|
||||
using mindspore::lite::kHWDimNumber;
|
||||
using mindspore::lite::kNCHWDimNumber;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_MEMORY_FAILED;
|
||||
using mindspore::lite::RET_OK;
|
||||
|
@ -105,7 +108,7 @@ void MatmulDynamicBaseInt8CPUKernel::ResizeMatrixBParameter() {
|
|||
for (size_t i = 0; i < w_shape.size() - kSize2; ++i) {
|
||||
batch *= w_shape[i];
|
||||
}
|
||||
param_->batch = batch;
|
||||
b_batch_ = batch;
|
||||
param_->col_ = param_->b_transpose_ ? w_shape[w_shape.size() - kSize2] : w_shape[w_shape.size() - kSize1];
|
||||
param_->deep_ = param_->b_transpose_ ? w_shape[w_shape.size() - kSize1] : w_shape[w_shape.size() - kSize2];
|
||||
|
||||
|
@ -155,7 +158,7 @@ int MatmulDynamicBaseInt8CPUKernel::InitInputQuantParam() {
|
|||
int MatmulDynamicBaseInt8CPUKernel::TransferB() {
|
||||
auto weight_data = reinterpret_cast<int8_t *>(in_tensors_.at(kWeightIndex)->data());
|
||||
CHECK_NULL_RETURN(weight_data);
|
||||
for (int i = 0; i < param_->batch; i++) {
|
||||
for (int i = 0; i < b_batch_; i++) {
|
||||
auto current_weight = weight_data + i * param_->deep_ * param_->col_;
|
||||
auto current_b_pack = pack_b_ptr_ + i * param_->col_align_ * param_->deep_align_;
|
||||
auto current_sums = weight_sums_ + i * param_->col_align_;
|
||||
|
@ -201,7 +204,7 @@ int MatmulDynamicBaseInt8CPUKernel::InitMatrixBBuffer() {
|
|||
pack_b_ptr_ = nullptr;
|
||||
}
|
||||
pack_b_ptr_ =
|
||||
reinterpret_cast<int8_t *>(malloc(param_->batch * param_->col_align_ * param_->deep_align_ * sizeof(int8_t)));
|
||||
reinterpret_cast<int8_t *>(malloc(b_batch_ * param_->col_align_ * param_->deep_align_ * sizeof(int8_t)));
|
||||
if (pack_b_ptr_ == nullptr) {
|
||||
FreeTmpBuffer();
|
||||
return RET_ERROR;
|
||||
|
@ -210,13 +213,13 @@ int MatmulDynamicBaseInt8CPUKernel::InitMatrixBBuffer() {
|
|||
free(weight_sums_);
|
||||
weight_sums_ = nullptr;
|
||||
}
|
||||
weight_sums_ = reinterpret_cast<int *>(malloc(param_->batch * param_->col_align_ * sizeof(int)));
|
||||
weight_sums_ = reinterpret_cast<int *>(malloc(b_batch_ * param_->col_align_ * sizeof(int)));
|
||||
if (weight_sums_ == nullptr) {
|
||||
FreeTmpBuffer();
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(pack_b_ptr_, 0, param_->batch * param_->col_align_ * param_->deep_align_ * sizeof(int8_t));
|
||||
memset(weight_sums_, 0, param_->batch * param_->col_align_ * sizeof(int));
|
||||
memset(pack_b_ptr_, 0, b_batch_ * param_->col_align_ * param_->deep_align_ * sizeof(int8_t));
|
||||
memset(weight_sums_, 0, b_batch_ * param_->col_align_ * sizeof(int));
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -289,27 +292,22 @@ int MatmulDynamicBaseInt8CPUKernel::Prepare() {
|
|||
|
||||
int MatmulDynamicBaseInt8CPUKernel::ReSize() {
|
||||
auto x_shape = in_tensors_.at(0)->shape();
|
||||
auto y_shape = in_tensors_.at(1)->shape();
|
||||
auto o_shape = out_tensors_.at(0)->shape();
|
||||
MS_ASSERT(o_shape.size() >= kSize2);
|
||||
unsigned int i = 0;
|
||||
param_->row_ = param_->a_transpose_ ? x_shape[x_shape.size() - kSize1] : x_shape[x_shape.size() - kSize2];
|
||||
param_->batch = 1;
|
||||
for (; i < x_shape.size() - kSize2; i++) {
|
||||
if (x_shape[i] != y_shape[i]) {
|
||||
break;
|
||||
}
|
||||
param_->batch *= x_shape[i];
|
||||
}
|
||||
for (; i < x_shape.size() - kSize2; i++) {
|
||||
param_->row_ *= x_shape[i];
|
||||
}
|
||||
|
||||
param_->row_ = o_shape[o_shape.size() - kSize2];
|
||||
param_->row_align_ = UP_ROUND(param_->row_, row_tile_);
|
||||
param_->deep_ = param_->a_transpose_ ? x_shape[x_shape.size() - kSize2] : x_shape[x_shape.size() - kSize1];
|
||||
param_->deep_align_ = UP_ROUND(param_->deep_, deep_tile_);
|
||||
|
||||
auto ret = InitMatrixABuffer();
|
||||
auto ret = InitBroadcastParams(in_tensors_[kInputIndex]->shape(), in_tensors_[kWeightIndex]->shape(), param_,
|
||||
&a_offset_, &b_offset_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "InitBroadcastParams failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
ret = InitMatrixABuffer();
|
||||
if (ret != RET_OK) {
|
||||
FreeQuantParam();
|
||||
return ret;
|
||||
|
@ -325,4 +323,74 @@ int MatmulDynamicBaseInt8CPUKernel::ReSize() {
|
|||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int MatmulDynamicBaseInt8CPUKernel::InitBroadcastParams(const std::vector<int> &a_shape_const,
|
||||
const std::vector<int> &b_shape_const, MatMulParameter *params,
|
||||
std::vector<int> *a_offsets, std::vector<int> *b_offsets) {
|
||||
std::vector<int> a_shape = a_shape_const;
|
||||
if (a_shape.size() < kNCHWDimNumber) {
|
||||
size_t add_nums = kNCHWDimNumber - a_shape.size();
|
||||
for (size_t i = 0; i < add_nums; ++i) {
|
||||
(void)a_shape.insert(a_shape.begin(), 1);
|
||||
}
|
||||
}
|
||||
std::vector<int> b_shape = b_shape_const;
|
||||
if (b_shape.size() < kNCHWDimNumber) {
|
||||
size_t add_nums = kNCHWDimNumber - b_shape.size();
|
||||
for (size_t i = 0; i < add_nums; ++i) {
|
||||
(void)b_shape.insert(b_shape.begin(), 1);
|
||||
}
|
||||
}
|
||||
|
||||
int batch_sizes[MAX_SHAPE_SIZE] = {0};
|
||||
int a_batch_sizes[MAX_SHAPE_SIZE] = {0};
|
||||
int b_batch_sizes[MAX_SHAPE_SIZE] = {0};
|
||||
for (int i = a_shape.size() - kCHWDimNumber; i >= 0; --i) {
|
||||
if (static_cast<int>(a_shape.size() - kCHWDimNumber) == i) {
|
||||
batch_sizes[i] = std::max(a_shape[i], b_shape[i]);
|
||||
a_batch_sizes[i] = a_shape[i];
|
||||
b_batch_sizes[i] = b_shape[i];
|
||||
} else {
|
||||
batch_sizes[i] = batch_sizes[i + 1] * std::max(a_shape[i], b_shape[i]);
|
||||
a_batch_sizes[i] = a_batch_sizes[i + 1] * a_shape[i];
|
||||
b_batch_sizes[i] = b_batch_sizes[i + 1] * b_shape[i];
|
||||
}
|
||||
}
|
||||
|
||||
int out_batch = 1;
|
||||
for (size_t i = 0; i < a_shape.size() - kHWDimNumber; ++i) {
|
||||
int max_v = MSMAX(a_shape[i], b_shape[i]);
|
||||
int min_v = MSMIN(a_shape[i], b_shape[i]) > 0 ? MSMIN(a_shape[i], b_shape[i]) : 1;
|
||||
out_batch *= max_v;
|
||||
if (max_v != min_v && max_v % min_v != 0) {
|
||||
MS_LOG(ERROR) << "matmul don't support broadcast for dimension " << a_shape << " and " << b_shape;
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
params->batch = out_batch;
|
||||
|
||||
a_offsets->resize(params->batch, 0);
|
||||
b_offsets->resize(params->batch, 0);
|
||||
for (int i = 0; i < params->batch; ++i) {
|
||||
int64_t delta = i;
|
||||
int a_offset = 0;
|
||||
int b_offset = 0;
|
||||
for (size_t j = 0; j < a_shape.size() - kHWDimNumber; ++j) {
|
||||
if (j > 0) {
|
||||
delta = delta % batch_sizes[j];
|
||||
}
|
||||
if (j < (a_shape.size() - kCHWDimNumber)) {
|
||||
a_offset += (delta / batch_sizes[j + 1] * a_shape[j] / std::max(a_shape[j], b_shape[j])) * a_batch_sizes[j + 1];
|
||||
b_offset += (delta / batch_sizes[j + 1] * b_shape[j] / std::max(a_shape[j], b_shape[j])) * b_batch_sizes[j + 1];
|
||||
} else {
|
||||
a_offset += (delta * a_shape[j] / std::max(a_shape[j], b_shape[j]));
|
||||
b_offset += (delta * b_shape[j] / std::max(a_shape[j], b_shape[j]));
|
||||
}
|
||||
}
|
||||
(*a_offsets)[i] = a_offset;
|
||||
(*b_offsets)[i] = b_offset;
|
||||
}
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -18,12 +18,14 @@
|
|||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_INT8_MATMUL_DYNAMIC_BASE_INT8_H_
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "include/errorcode.h"
|
||||
#include "src/litert/lite_kernel.h"
|
||||
#include "nnacl/matmul_parameter.h"
|
||||
#include "nnacl/common_func.h"
|
||||
#include "nnacl/int8/quantize.h"
|
||||
#include "nnacl/int8/common_func_int8.h"
|
||||
#include "src/common/common.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class MatmulDynamicBaseInt8CPUKernel : public LiteKernel {
|
||||
|
@ -36,6 +38,8 @@ class MatmulDynamicBaseInt8CPUKernel : public LiteKernel {
|
|||
~MatmulDynamicBaseInt8CPUKernel() override;
|
||||
int Prepare() override;
|
||||
int ReSize() override;
|
||||
static int InitBroadcastParams(const std::vector<int> &a_shape_const, const std::vector<int> &b_shape_const,
|
||||
MatMulParameter *params, std::vector<int> *a_offsets, std::vector<int> *b_offsets);
|
||||
|
||||
private:
|
||||
void ResizeMatrixBParameter();
|
||||
|
@ -45,6 +49,10 @@ class MatmulDynamicBaseInt8CPUKernel : public LiteKernel {
|
|||
int MallocQuantParam();
|
||||
|
||||
protected:
|
||||
int a_batch_ = 1;
|
||||
int b_batch_ = 1;
|
||||
std::vector<int> a_offset_;
|
||||
std::vector<int> b_offset_;
|
||||
typedef void (*PackFunc)(const int8_t *src, int8_t *dst, int row, int col);
|
||||
virtual void InitParameter() = 0;
|
||||
int TransferA();
|
||||
|
@ -63,6 +71,7 @@ class MatmulDynamicBaseInt8CPUKernel : public LiteKernel {
|
|||
bool filter_per_channel_ = true;
|
||||
int8_t *batch_input_ptr_ = nullptr;
|
||||
int8_t *batch_weight_ptr_ = nullptr;
|
||||
int8_t *batch_a_ptr_ = nullptr;
|
||||
int8_t *batch_b_ptr_ = nullptr;
|
||||
float *batch_c_ptr_ = nullptr;
|
||||
int *input_sums_ = nullptr;
|
||||
|
|
|
@ -54,7 +54,7 @@ int MatmulDynamicInt8CPUKernel::RunImpl(int task_id) {
|
|||
if (filter_per_channel_) {
|
||||
filter_scale += cur_stride;
|
||||
}
|
||||
DynamicMatmul4x16x4AIWI(pack_a_ptr_, batch_b_ptr_ + cur_stride * param_->deep_align_, bias_ptr,
|
||||
DynamicMatmul4x16x4AIWI(batch_a_ptr_, batch_b_ptr_ + cur_stride * param_->deep_align_, bias_ptr,
|
||||
batch_c_ptr_ + cur_stride, param_->row_, cur_oc, param_->deep_, param_->deep_align_,
|
||||
param_->col_, quant_param_->input_zp_, quant_param_->input_scale_, filter_scale, filter_zp,
|
||||
filter_per_channel_);
|
||||
|
@ -105,7 +105,7 @@ int MatmulDynamicInt8CPUKernel::Run() {
|
|||
CHECK_NULL_RETURN(c_ptr);
|
||||
for (int i = 0; i < param_->batch; i++) {
|
||||
memset(pack_a_ptr_, quant_param_->input_zp_, param_->row_align_ * param_->deep_align_ * sizeof(int8_t));
|
||||
auto current_src_a = a_ptr + i * param_->row_ * param_->deep_;
|
||||
auto current_src_a = a_ptr + a_offset_[i] * param_->row_ * param_->deep_;
|
||||
if (param_->a_transpose_) {
|
||||
MS_CHECK_TRUE_RET(a_pack_func_ != nullptr, RET_ERROR);
|
||||
a_pack_func_(current_src_a, pack_a_ptr_, param_->deep_, param_->row_);
|
||||
|
@ -114,7 +114,8 @@ int MatmulDynamicInt8CPUKernel::Run() {
|
|||
a_pack_func_(current_src_a, pack_a_ptr_, param_->row_, param_->deep_);
|
||||
}
|
||||
|
||||
batch_b_ptr_ = pack_b_ptr_ + i * param_->col_align_ * param_->deep_align_;
|
||||
batch_a_ptr_ = pack_a_ptr_;
|
||||
batch_b_ptr_ = pack_b_ptr_ + b_offset_[i] * param_->col_align_ * param_->deep_align_;
|
||||
batch_c_ptr_ = c_ptr + i * param_->row_ * param_->col_;
|
||||
|
||||
ret = ParallelLaunch(this->ms_context_, MatmulDynamicInt8Run, this, thread_count_);
|
||||
|
|
|
@ -167,16 +167,16 @@ int MatMulDynamicSdotInt8Kernel::MatMulDynamicRunArm64Sdot() {
|
|||
CHECK_NULL_RETURN(c_ptr);
|
||||
|
||||
for (int i = 0; i < param_->batch; i++) {
|
||||
batch_input_ptr_ = a_ptr + i * param_->row_ * param_->deep_;
|
||||
batch_input_ptr_ = a_ptr + a_offset_[i] * param_->row_ * param_->deep_;
|
||||
auto ret = ParallelLaunch(this->ms_context_, Arm64SdotPreRun, this, op_parameter_->thread_num_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Arm64SdotPreRun error: [" << ret << "]";
|
||||
return ret;
|
||||
}
|
||||
|
||||
batch_weight_ptr_ = b_ptr + i * param_->col_ * param_->deep_;
|
||||
batch_sums_ = weight_sums_ + i * param_->col_align_;
|
||||
batch_b_ptr_ = pack_b_ptr_ + i * param_->col_align_ * param_->deep_align_;
|
||||
batch_weight_ptr_ = b_ptr + b_offset_[i] * param_->col_ * param_->deep_;
|
||||
batch_sums_ = weight_sums_ + b_offset_[i] * param_->col_align_;
|
||||
batch_b_ptr_ = pack_b_ptr_ + b_offset_[i] * param_->col_align_ * param_->deep_align_;
|
||||
batch_c_ptr_ = c_ptr + i * param_->row_ * param_->col_;
|
||||
|
||||
ret = ParallelLaunch(this->ms_context_, Arm64SdotRun, this, thread_count_);
|
||||
|
|
Loading…
Reference in New Issue