support matmul fp32 pack.

This commit is contained in:
wangpingan2 2023-02-27 14:44:03 +08:00
parent d22f57ae7a
commit 7befa6baa4
18 changed files with 229 additions and 94 deletions

View File

@ -41,6 +41,8 @@ typedef enum MatmulType {
kMatmulInt8Cpu,
kMatmulDynamicInt8Cpu,
kMatmulDynamicSdotInt8Cpu,
kMatmulFp32BaseCpu,
kMatmulFp32Arm64Cpu,
} MatmulType;
typedef struct MatMulParameter {

View File

@ -111,5 +111,10 @@ MatmulFp32BaseCPUKernel *CreateMatmulFp32CPUKernel(OpParameter *parameter, const
return kernel;
}
int MatmulCPUKernel::PreparePackedWeight(const lite::Tensor *tensor) {
matmul_base_->SetWeightIsPacked(true);
return RET_OK;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_MatMulFusion, LiteKernelCreator<MatmulCPUKernel>)
} // namespace mindspore::kernel

View File

@ -88,6 +88,9 @@ class MatmulCPUKernel : public LiteKernel {
int ReSize() override;
int Run() override;
int PreparePackedWeight(const lite::Tensor *tensor) override;
MatmulFp32BaseCPUKernel *GetMatmulBase() const { return matmul_base_; }
private:
MatmulFp32BaseCPUKernel *matmul_base_ = nullptr;
};

View File

@ -24,7 +24,9 @@ class MatmulFp32ARM32CPUKernel : public MatmulFp32BaseCPUKernel {
public:
MatmulFp32ARM32CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const mindspore::lite::InnerContext *ctx)
: MatmulFp32BaseCPUKernel(parameter, inputs, outputs, ctx) {}
: MatmulFp32BaseCPUKernel(parameter, inputs, outputs, ctx) {
params_->matmul_type_ = kNotImplemented;
}
~MatmulFp32ARM32CPUKernel() = default;
void InitGlobalVariable() override;

View File

@ -28,7 +28,7 @@ constexpr int64_t kPackAMinUnitNum = 1 << 13;
} // namespace
void MatmulFp32ARM64CPUKernel::InitGlobalVariable() {
matrix_a_.need_pack = true;
matrix_b_.need_pack = true;
matrix_b_.need_pack = !weight_is_packed_;
matrix_a_pack_fun_ = params_->a_transpose_ ? RowMajor2Row12MajorParallel : RowMajor2Col12MajorParallel;
matrix_b_pack_fun_ = params_->b_transpose_ ? RowMajor2Col8MajorParallel : RowMajor2Row8MajorParallel;
pack_opt_ = true;

View File

@ -25,7 +25,9 @@ class MatmulFp32ARM64CPUKernel : public MatmulFp32BaseCPUKernel {
public:
MatmulFp32ARM64CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const mindspore::lite::InnerContext *ctx)
: MatmulFp32BaseCPUKernel(parameter, inputs, outputs, ctx) {}
: MatmulFp32BaseCPUKernel(parameter, inputs, outputs, ctx) {
params_->matmul_type_ = kMatmulFp32Arm64Cpu;
}
~MatmulFp32ARM64CPUKernel() = default;
void InitGlobalVariable() override;

View File

@ -25,7 +25,9 @@ class MatmulFp32AVXCPUKernel : public MatmulFp32BaseCPUKernel {
public:
MatmulFp32AVXCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const mindspore::lite::InnerContext *ctx)
: MatmulFp32BaseCPUKernel(parameter, inputs, outputs, ctx) {}
: MatmulFp32BaseCPUKernel(parameter, inputs, outputs, ctx) {
params_->matmul_type_ = kNotImplemented;
}
~MatmulFp32AVXCPUKernel() = default;
void InitGlobalVariable() override;

View File

@ -32,7 +32,9 @@ class MatmulFp32AVX512CPUKernel : public MatmulFp32BaseCPUKernel {
public:
MatmulFp32AVX512CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const mindspore::lite::InnerContext *ctx)
: MatmulFp32BaseCPUKernel(parameter, inputs, outputs, ctx) {}
: MatmulFp32BaseCPUKernel(parameter, inputs, outputs, ctx) {
params_->matmul_type_ = kNotImplemented;
}
~MatmulFp32AVX512CPUKernel() = default;
void InitGlobalVariable() override;

View File

@ -60,6 +60,9 @@ MatmulFp32BaseCPUKernel::~MatmulFp32BaseCPUKernel() {
}
}
if (params_->b_const_) {
if (!matrix_b_.need_pack && weight_is_packed_) {
return;
}
if (is_sharing_pack_) {
lite::PackWeightManager::GetInstance()->Free(matrix_b_.pack_ptr);
} else {
@ -70,7 +73,7 @@ MatmulFp32BaseCPUKernel::~MatmulFp32BaseCPUKernel() {
void MatmulFp32BaseCPUKernel::InitGlobalVariable() {
matrix_a_.need_pack = true;
matrix_b_.need_pack = true;
matrix_b_.need_pack = !weight_is_packed_;
matrix_a_pack_fun_ = params_->a_transpose_ ? RowMajor2Row12MajorParallel : RowMajor2Col12MajorParallel;
matrix_b_pack_fun_ = params_->b_transpose_ ? RowMajor2Col8MajorParallel : RowMajor2Row8MajorParallel;
row_tile_ = C12NUM;
@ -239,6 +242,10 @@ int MatmulFp32BaseCPUKernel::PackMatrixB() {
reinterpret_cast<float *>(ms_context_->allocator->Malloc(matrix_b_.pack_size * sizeof(float)));
}
} else {
if (!matrix_b_.need_pack && weight_is_packed_) {
matrix_b_.pack_ptr = reinterpret_cast<float *>(in_tensors_[SECOND_INPUT]->data());
return RET_OK;
}
bool is_packed = false;
void *data = nullptr;
if (is_sharing_pack_) {

View File

@ -39,6 +39,7 @@ class MatmulFp32BaseCPUKernel : public LiteKernel {
const std::vector<lite::Tensor *> &outputs, const mindspore::lite::InnerContext *ctx)
: LiteKernel(parameter, inputs, outputs, ctx) {
params_ = reinterpret_cast<MatMulParameter *>(op_parameter_);
params_->matmul_type_ = kMatmulFp32BaseCpu;
}
~MatmulFp32BaseCPUKernel() override;
int Prepare() override;
@ -72,6 +73,10 @@ class MatmulFp32BaseCPUKernel : public LiteKernel {
virtual int ParallelRunByRow1Deep1GEPDOT(int task_id) const { return RET_ERROR; }
virtual int GetThreadCuttingPolicy();
const float *GetPackBPtr() const { return matrix_b_.pack_ptr; }
const int GetBBatch() const { return b_batch_; }
void SetWeightIsPacked(bool weight_is_packed) { this->weight_is_packed_ = weight_is_packed; }
public:
struct MatrixInfo {
bool need_pack{false};
@ -140,6 +145,7 @@ class MatmulFp32BaseCPUKernel : public LiteKernel {
float *conv1x1_origin_weight_ = nullptr;
float *conv1x1_origin_bias_ = nullptr;
bool is_sharing_pack_ = true;
bool weight_is_packed_{false};
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_MATMUL_FP32_BASE_H_

View File

@ -25,7 +25,9 @@ class MatmulFp32SSECPUKernel : public MatmulFp32BaseCPUKernel {
public:
MatmulFp32SSECPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const mindspore::lite::InnerContext *ctx)
: MatmulFp32BaseCPUKernel(parameter, inputs, outputs, ctx) {}
: MatmulFp32BaseCPUKernel(parameter, inputs, outputs, ctx) {
params_->matmul_type_ = kNotImplemented;
}
~MatmulFp32SSECPUKernel() = default;
void InitGlobalVariable() override;

View File

@ -433,4 +433,10 @@ int MatmulDynamicBaseInt8CPUKernel::InitBroadcastParams(const std::vector<int> &
return RET_OK;
}
int MatmulDynamicBaseInt8CPUKernel::PreparePackedWeight(const lite::Tensor *tensor) {
weight_is_packed_ = true;
weight_sums_tensor_ = tensor;
return RET_OK;
}
} // namespace mindspore::kernel

View File

@ -45,8 +45,7 @@ class MatmulDynamicBaseInt8CPUKernel : public LiteKernel {
const int8_t *GetPackBPtr() const { return pack_b_ptr_; }
const int *GetWeightSums() const { return weight_sums_; }
const int GetBBatch() const { return b_batch_; }
void SetWeightIsPacked(bool weight_is_packed) { this->weight_is_packed_ = weight_is_packed; }
void SetWeightSumsTensor(lite::Tensor *weight_sums_tensor) { this->weight_sums_tensor_ = weight_sums_tensor; }
int PreparePackedWeight(const lite::Tensor *tensor) override;
private:
void ResizeMatrixBParameter();
@ -97,7 +96,7 @@ class MatmulDynamicBaseInt8CPUKernel : public LiteKernel {
bool enable_fp16_ = false;
PackFunc b_pack_func_ = nullptr;
bool weight_is_packed_ = false;
lite::Tensor *weight_sums_tensor_ = nullptr;
const lite::Tensor *weight_sums_tensor_ = nullptr;
};
} // namespace mindspore::kernel

View File

@ -182,6 +182,8 @@ class MS_API LiteKernel : public Abstractkernel {
}
bool ws_allocated_ = false;
virtual int PreparePackedWeight(const lite::Tensor *tensor) { return mindspore::lite::RET_OK; }
protected:
virtual int UpdateThreadNumProcess(int32_t kernel_type, int64_t per_unit_load_num, int64_t per_unit_store_num,
int64_t unit_num);

View File

@ -15,7 +15,7 @@
*/
#include "src/litert/runtime_packed_node_pass.h"
#include "nnacl/op_base.h"
#include "src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.h"
#include "nnacl/matmul_parameter.h"
using RecoveryWeightFunc = void (*)(void *, void *, int, int, bool);
namespace mindspore {
@ -75,9 +75,9 @@ void PackedNodePass::Run(Model *model, const std::vector<Tensor *> &tensors) {
MS_LOG(ERROR) << "Custom attr error.";
return;
}
auto val_offset = schema::CreateMatMulFusion(
fbb, std::atoi(attr_map[kTransposeA].c_str()), std::atoi(attr_map[kTransposeB].c_str()),
static_cast<schema::ActivationType>(std::atoi(attr_map[kActivationType].c_str())));
auto val_offset =
schema::CreateMatMulFusion(fbb, std::stoi(attr_map[kTransposeA]), std::stoi(attr_map[kTransposeB]),
static_cast<schema::ActivationType>(std::stoi(attr_map[kActivationType])));
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_MatMulFusion, val_offset.o);
fbb.Finish(prim_offset);
void *prim = malloc(fbb.GetSize());
@ -96,21 +96,23 @@ void PackedNodePass::Run(Model *model, const std::vector<Tensor *> &tensors) {
}
node->primitive_ = custom_primitive;
pack_info->is_packed_ = true;
pack_info->weight_sums_index_ = node->input_indices_.back();
pack_info->b_batch_ = std::atoi(attr_map["b_batch"].c_str());
pack_info->col_ = std::atoi(attr_map["col"].c_str());
pack_info->deep_ = std::atoi(attr_map["deep"].c_str());
pack_info->col_align_ = std::atoi(attr_map["col_align"].c_str());
pack_info->deep_align_ = std::atoi(attr_map["deep_align"].c_str());
pack_info->b_transpose_ = std::atoi(attr_map[kTransposeB].c_str());
pack_info->b_batch_ = std::stoi(attr_map["b_batch"]);
pack_info->col_ = std::stoi(attr_map["col"]);
pack_info->deep_ = std::stoi(attr_map["deep"]);
pack_info->col_align_ = std::stoi(attr_map["col_align"]);
pack_info->deep_align_ = std::stoi(attr_map["deep_align"]);
pack_info->b_transpose_ = std::stoi(attr_map[kTransposeB]);
pack_info->cpu_option_ = attr_map["cpu_option"];
AddNodePackInfo(node->name_, pack_info);
node->input_indices_.pop_back();
node->node_type_ = schema::PrimitiveType_MatMulFusion;
}
if (node->quant_type_ == schema::QuantType_QUANT_DYNAMIC) {
pack_info->weight_sums_index_ = node->input_indices_.back();
node->input_indices_.pop_back();
if (!(reinterpret_cast<lite::LiteModel *>(model)->keep_model_buf())) {
CopyWeightBiasSumsTensor(tensors);
}
}
if (!(reinterpret_cast<lite::LiteModel *>(model)->keep_model_buf())) {
CopyWeightBiasSumsTensor(tensors);
node->node_type_ = schema::PrimitiveType_MatMulFusion;
}
}
@ -180,12 +182,69 @@ void MatmulDynamicSdotInt8Cpu(void *src, void *dst, int row, int col, bool trans
}
}
void MatmulFp32BaseCpu(void *src, void *dst, int row, int col, bool transpose) {
if (!transpose) {
// RowMajor2Row8MajorParallel
auto src_r = static_cast<float *>(src);
auto dst_r = static_cast<float *>(dst);
for (int r = 0; r < row; r++) {
float *src_c = src_r + r * col;
int c = 0;
for (; c < col; c++) {
int cd8 = c / C8NUM;
int cm8 = c % C8NUM;
src_c[c] = dst_r[cd8 * C8NUM * row + r * C8NUM + cm8];
}
}
return;
}
// RowMajor2Col8MajorParallel
auto src_r = static_cast<float *>(src);
auto dst_r = static_cast<float *>(dst);
int row8 = row / C8NUM * C8NUM;
int col_skip = col / C4NUM * C4NUM;
int skip_size = C4NUM;
int ri = 0;
for (; ri < row8; ri += C8NUM) {
int ci = 0;
for (; ci < col_skip; ci += skip_size) {
float *src_c = src_r + ci;
float *dst_c = dst_r + ci * C8NUM;
for (int tr = 0; tr < C8NUM; tr++) {
for (int tc = 0; tc < C4NUM; tc++) {
src_c[tr * col + tc] = dst_c[tc * C8NUM + tr];
}
}
}
for (; ci < col; ci++) {
float *src_c = src_r + ci;
float *dst_c = dst_r + ci * C8NUM;
for (int i = 0; i < C8NUM; i++) {
src_c[i * col] = dst_c[i];
}
}
src_r += C8NUM * col;
dst_r += C8NUM * col;
}
for (; ri < row; ri++, src_r += col, dst_r++) {
for (int i = 0; i < col; i++) {
src_r[i] = dst_r[i * C8NUM];
}
}
}
RecoveryWeightFunc GetRecoveryWeightFunc(const int quant_type, const TypeId data_type, const int node_type,
const std::string &cpu_option) {
if (cpu_option == kArm64SimdDot && node_type == schema::PrimitiveType_MatMulFusion &&
quant_type == schema::QuantType_QUANT_DYNAMIC && data_type == kNumberTypeInt8) {
return MatmulDynamicSdotInt8Cpu;
}
if (cpu_option == kArm64SimdDot && node_type == schema::PrimitiveType_MatMulFusion &&
data_type == kNumberTypeFloat32) {
return MatmulFp32BaseCpu;
}
return nullptr;
}
@ -200,23 +259,26 @@ int PackedMatmulKernelExec(kernel::KernelExec *kernel_exec, const std::vector<Te
auto kernel = kernel_exec->kernel();
MS_CHECK_TRUE_MSG(kernel != nullptr, lite::RET_NULL_PTR, "kernel is nullptr.");
auto param = reinterpret_cast<MatMulParameter *>(kernel_exec->op_parameter());
if (dst_tensor->data_type() != kNumberTypeInt8 || kernel->quant_type() != schema::QuantType_QUANT_DYNAMIC) {
if (dst_tensor->data_type() == kNumberTypeFloat32) {
if (param->matmul_type_ == kNotImplemented) {
return RecoveryPackedWeight(dst_tensor, static_cast<int>(kernel->quant_type()), dst_tensor->data_type(),
schema::PrimitiveType_MatMulFusion, pack_info);
}
}
if (dst_tensor->data_type() == kNumberTypeInt8 && param->matmul_type_ != kMatmulDynamicSdotInt8Cpu &&
pack_info->cpu_option_ == kArm64SimdDot) {
return RecoveryPackedWeight(dst_tensor, static_cast<int>(kernel->quant_type()), dst_tensor->data_type(),
schema::PrimitiveType_MatMulFusion, pack_info);
}
if (param->matmul_type_ != kMatmulDynamicSdotInt8Cpu && pack_info->cpu_option_ == kArm64SimdDot) {
return RecoveryPackedWeight(dst_tensor, static_cast<int>(kernel->quant_type()), dst_tensor->data_type(),
schema::PrimitiveType_MatMulFusion, pack_info);
}
auto matmul_kernel = static_cast<kernel::MatmulDynamicBaseInt8CPUKernel *>(kernel);
matmul_kernel->SetWeightIsPacked(true);
auto lite_kernel = static_cast<kernel::LiteKernel *>(kernel);
lite::Tensor *weight_sums = nullptr;
auto index = static_cast<size_t>(pack_info->weight_sums_index_);
if (index < tensors.size()) {
matmul_kernel->SetWeightSumsTensor(tensors.at(index));
weight_sums = tensors.at(index);
}
return lite::RET_OK;
return lite_kernel->PreparePackedWeight(weight_sums);
}
int RecoveryPackedWeight(Tensor *weight, const int quant_type, const TypeId data_type, const int node_type,
@ -239,6 +301,10 @@ int RecoveryPackedWeight(Tensor *weight, const int quant_type, const TypeId data
current_weight = static_cast<void *>(static_cast<int8_t *>(unpack_data) + i * pack_info->deep_ * pack_info->col_);
current_b_pack =
static_cast<void *>(static_cast<int8_t *>(pack_b_ptr) + i * pack_info->col_align_ * pack_info->deep_align_);
} else if (weight->data_type() == kNumberTypeFloat32) {
current_weight = static_cast<void *>(static_cast<float *>(unpack_data) + i * pack_info->deep_ * pack_info->col_);
current_b_pack =
static_cast<void *>(static_cast<float *>(pack_b_ptr) + i * pack_info->col_align_ * pack_info->deep_);
} else {
free(unpack_data);
MS_LOG(ERROR) << "unsupported data type.";

View File

@ -28,13 +28,13 @@ namespace mindspore {
namespace lite {
struct PackInfo {
bool is_packed_{false};
int weight_sums_index_;
int weight_sums_index_{-1};
int b_batch_;
int deep_;
int col_;
int deep_align_;
int col_align_;
bool b_transpose_;
bool b_transpose_{false};
std::string cpu_option_;
};

View File

@ -21,6 +21,7 @@
#include "tools/converter/offline_packing_optimizer.h"
#include "src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.h"
#include "mindspore/core/ops/op_name.h"
#include "src/litert/kernel/cpu/fp32/matmul_fp32.h"
namespace mindspore {
namespace {
@ -37,6 +38,27 @@ void AddCustomAttr(std::vector<std::unique_ptr<mindspore::schema::AttributeT>> *
attrs->emplace_back(std::move(attr));
}
int AddWeightSumsToInputs(const mindspore::kernel::MatmulDynamicBaseInt8CPUKernel *matmul_kernel,
schema::MetaGraphT *meta_graph, const std::unique_ptr<schema::CNodeT> &cnode,
size_t weight_sum_size) {
auto weight_sums_tensor = std::make_unique<schema::TensorT>();
weight_sums_tensor->nodeType = lite::NodeType_ValueNode;
weight_sums_tensor->format = schema::Format_NHWC;
weight_sums_tensor->dataType = TypeId::kNumberTypeInt32;
weight_sums_tensor->dims = {};
weight_sums_tensor->dims.emplace_back(weight_sum_size / sizeof(int));
weight_sums_tensor->data.resize(weight_sum_size);
weight_sums_tensor->name = cnode->name + "_weight_sums";
if (memcpy_s(weight_sums_tensor->data.data(), weight_sums_tensor->data.size(), matmul_kernel->GetWeightSums(),
weight_sum_size) != EOK) {
MS_LOG(ERROR) << "new CustomT error.";
return RET_ERROR;
}
cnode->inputIndex.emplace_back(meta_graph->allTensors.size());
meta_graph->allTensors.emplace_back(std::move(weight_sums_tensor));
return RET_OK;
}
int ReplaceMatMulFusionToCustom(schema::MetaGraphT *meta_graph, const std::unique_ptr<schema::CNodeT> &cnode,
const std::unique_ptr<mindspore::schema::TensorT> &b_input,
const std::string &cpu_option) {
@ -51,65 +73,75 @@ int ReplaceMatMulFusionToCustom(schema::MetaGraphT *meta_graph, const std::uniqu
return RET_ERROR;
}
auto matmul_param = reinterpret_cast<MatMulParameter *>(param);
if (matmul_param->matmul_type_ == kNotImplemented) {
MS_LOG(ERROR) << "Unsupported matmul type, only support fp32 and dynamic quant int8.";
return RET_ERROR;
}
cnode->primitive->value.type = schema::PrimitiveType_Custom;
auto primitive = new (std::nothrow) schema::CustomT;
if (primitive == nullptr) {
MS_LOG(ERROR) << "new CustomT error.";
return RET_NULL_PTR;
}
primitive->type = kMatmulCustomType;
// activation_type
AddCustomAttr(&(primitive->attr), ops::kActivationType, std::to_string(matmul_param->act_type_));
// transpose_a
AddCustomAttr(&(primitive->attr), ops::kTransposeA, std::to_string(matmul_param->a_transpose_));
// transpose_b
AddCustomAttr(&(primitive->attr), ops::kTransposeB, std::to_string(matmul_param->b_transpose_));
int b_batch;
const void *pack_b_ptr = nullptr;
size_t pack_b_size;
if (matmul_param->matmul_type_ == kMatmulDynamicSdotInt8Cpu) {
cnode->primitive->value.type = schema::PrimitiveType_Custom;
auto primitive = new (std::nothrow) schema::CustomT;
if (primitive == nullptr) {
MS_LOG(ERROR) << "new CustomT error.";
return RET_NULL_PTR;
}
primitive->type = kMatmulCustomType;
// activation_type
AddCustomAttr(&(primitive->attr), ops::kActivationType, std::to_string(matmul_param->act_type_));
// transpose_a
AddCustomAttr(&(primitive->attr), ops::kTransposeA, std::to_string(matmul_param->a_transpose_));
// transpose_b
AddCustomAttr(&(primitive->attr), ops::kTransposeB, std::to_string(matmul_param->b_transpose_));
// replace packed data
auto matmul_kernel = reinterpret_cast<const mindspore::kernel::MatmulDynamicBaseInt8CPUKernel *>(lite_kernel);
auto b_batch = matmul_kernel->GetBBatch();
auto pack_b_size = b_batch * matmul_param->col_align_ * matmul_param->deep_align_ * sizeof(int8_t);
b_input->data.resize(pack_b_size);
if (memcpy_s(b_input->data.data(), b_input->data.size(), matmul_kernel->GetPackBPtr(), pack_b_size) != EOK) {
delete primitive;
MS_LOG(ERROR) << "new CustomT error.";
return RET_ERROR;
}
// add weight_sums to inputs
b_batch = matmul_kernel->GetBBatch();
pack_b_size = b_batch * matmul_param->col_align_ * matmul_param->deep_align_ * sizeof(int8_t);
pack_b_ptr = reinterpret_cast<const void *>(matmul_kernel->GetPackBPtr());
auto weight_sum_size = b_batch * matmul_param->col_align_ * sizeof(int);
auto weight_sums_tensor = std::make_unique<schema::TensorT>();
weight_sums_tensor->nodeType = lite::NodeType_ValueNode;
weight_sums_tensor->format = schema::Format_NHWC;
weight_sums_tensor->dataType = TypeId::kNumberTypeInt32;
weight_sums_tensor->dims = {};
weight_sums_tensor->dims.emplace_back(weight_sum_size / sizeof(int));
weight_sums_tensor->data.resize(weight_sum_size);
weight_sums_tensor->name = cnode->name + "_weight_sums";
if (memcpy_s(weight_sums_tensor->data.data(), weight_sums_tensor->data.size(), matmul_kernel->GetWeightSums(),
weight_sum_size) != EOK) {
int ret = AddWeightSumsToInputs(matmul_kernel, meta_graph, cnode, weight_sum_size);
if (ret != RET_OK) {
delete primitive;
MS_LOG(ERROR) << "new CustomT error.";
return RET_ERROR;
MS_LOG(ERROR) << "add weight sums to inputs error.";
return ret;
}
cnode->inputIndex.emplace_back(meta_graph->allTensors.size());
meta_graph->allTensors.emplace_back(std::move(weight_sums_tensor));
// add scalar to attr
AddCustomAttr(&(primitive->attr), "b_batch", std::to_string(b_batch));
AddCustomAttr(&(primitive->attr), "deep", std::to_string(matmul_param->deep_));
AddCustomAttr(&(primitive->attr), "col", std::to_string(matmul_param->col_));
AddCustomAttr(&(primitive->attr), "col_align", std::to_string(matmul_param->col_align_));
AddCustomAttr(&(primitive->attr), "deep_align", std::to_string(matmul_param->deep_align_));
// add cpu option
std::string cpu_option_str = cpu_option;
AddCustomAttr(&(primitive->attr), "cpu_option", std::move(cpu_option_str));
cnode->primitive->value.value = primitive;
} else if (matmul_param->matmul_type_ == kMatmulFp32BaseCpu || matmul_param->matmul_type_ == kMatmulFp32Arm64Cpu) {
auto matmul_kernel = reinterpret_cast<const mindspore::kernel::MatmulCPUKernel *>(lite_kernel);
auto matmul_kernel_base = matmul_kernel->GetMatmulBase();
b_batch = matmul_kernel_base->GetBBatch();
pack_b_size = b_batch * matmul_param->col_align_ * matmul_param->deep_ * sizeof(float);
pack_b_ptr = reinterpret_cast<const void *>(matmul_kernel_base->GetPackBPtr());
}
if (pack_b_ptr == nullptr) {
delete primitive;
MS_LOG(ERROR) << "pack_b_ptr is nullptr.";
return RET_NULL_PTR;
}
// copy packed weight to meta graph
b_input->data.resize(pack_b_size);
if (memcpy_s(b_input->data.data(), b_input->data.size(), pack_b_ptr, pack_b_size) != EOK) {
delete primitive;
MS_LOG(ERROR) << "memcpy packed weight error.";
return RET_ERROR;
}
// add scalar to attr
AddCustomAttr(&(primitive->attr), "b_batch", std::to_string(b_batch));
AddCustomAttr(&(primitive->attr), "deep", std::to_string(matmul_param->deep_));
AddCustomAttr(&(primitive->attr), "col", std::to_string(matmul_param->col_));
AddCustomAttr(&(primitive->attr), "col_align", std::to_string(matmul_param->col_align_));
AddCustomAttr(&(primitive->attr), "deep_align", std::to_string(matmul_param->deep_align_));
// add cpu option
std::string cpu_option_str = cpu_option;
AddCustomAttr(&(primitive->attr), "cpu_option", std::move(cpu_option_str));
cnode->primitive->value.value = primitive;
return RET_OK;
}

View File

@ -40,6 +40,7 @@ const char kAndroidArmCpuBackendOption[] = "ANDROID_ARM_CPU";
mindspore::lite::InnerContext *InitInnerContextForAndroidArmCpu() {
// if the operation use thread_pool in inner context will throw exception.
auto inner_context = new (std::nothrow) lite::InnerContext();
inner_context->Init();
MS_CHECK_TRUE_MSG(inner_context != nullptr, nullptr, "Create InnerContext failed.");
inner_context->thread_num_ = kSingleThread;
inner_context->instructions_ctx_.support_sdot = true;
@ -234,10 +235,6 @@ STATUS MatmulPacking(const mindspore::CNodePtr &cnode_ptr, const FuncGraphPtr &f
}
op_parameter->thread_num_ = kSingleThread;
op_parameter->quant_type_ = GetQuantType(cnode_ptr);
if (op_parameter->quant_type_ != schema::QuantType::QuantType_QUANT_DYNAMIC) {
MS_LOG(DEBUG) << "Only do pack for dynamic quant matmul operation now, skip " << cnode_ptr->fullname_with_scope();
return RET_OK;
}
(void)snprintf(op_parameter->name_, cnode_ptr->fullname_with_scope().length() + 1, "%s",
cnode_ptr->fullname_with_scope().c_str());