fix pdr_pdr_adf

This commit is contained in:
z00512249 2021-04-12 11:09:38 +08:00
parent 99a5dacdc7
commit dd51f4e67f
17 changed files with 676 additions and 298 deletions

View File

@ -109,6 +109,7 @@ endif()
string(REPLACE "/mindspore/lite" "" TOP_DIR ${CMAKE_CURRENT_SOURCE_DIR})
set(CORE_DIR ${TOP_DIR}/mindspore/core)
set(CCSRC_DIR ${TOP_DIR}/mindspore/ccsrc)
set(NNACL_DIR ${CCSRC_DIR}/backend/kernel_compiler/cpu/nnacl)
include_directories(${TOP_DIR})
include_directories(${CORE_DIR})
include_directories(${CORE_DIR}/ir)

View File

@ -6,7 +6,7 @@ option(PLATFORM_ARM32 "build operator for android arm 32" off)
option(PLATFORM_ARM64 "build operator for android arm 64" off)
string(REPLACE "/mindspore/lite/micro" "" TOP_DIR ${CMAKE_CURRENT_SOURCE_DIR})
set(NNACL_DIR ${TOP_DIR}/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl)
include_directories(${CMAKE_BINARY_DIR})
include(${TOP_DIR}/cmake/utils.cmake)
include(${TOP_DIR}/cmake/dependency_utils.cmake)

View File

@ -177,7 +177,6 @@ set(LITE_SRC
${LITE_DIR}/tools/common/flag_parser.cc
)
set(NNACL_DIR ${TOP_DIR}/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl)
set(LITE_KERNEL_SRC
### nnacl
${NNACL_DIR}/common_func.c
@ -186,6 +185,7 @@ set(LITE_KERNEL_SRC
${NNACL_DIR}/base/slice_base.c
${NNACL_DIR}/fp32/winograd_utils.c
${NNACL_DIR}/fp32/pack_fp32.c
${NNACL_DIR}/fp32/arithmetic_fp32.c
${NNACL_DIR}/int8/quantize.c
${NNACL_DIR}/int8/pack_int8.c
${NNACL_DIR}/int8/matmul_int8.c

View File

@ -6,6 +6,7 @@ set(WRAPPER_SRC
${WRAPPER_DIR}/base/detection_post_process_base_wrapper.c
${WRAPPER_DIR}/base/optimize_handler_wrapper.c
${WRAPPER_DIR}/fp32/matmul_fp32_wrapper.c
${WRAPPER_DIR}/fp32/arithmetic_fp32_wrapper.c
${WRAPPER_DIR}/int8/matmul_int8_wrapper.c
${WRAPPER_DIR}/int8/add_int8_wrapper.c
${WRAPPER_DIR}/int8/concat_int8_wrapper.c

View File

@ -18,7 +18,7 @@ include_directories(${3RD_DIR}/flatbuffers/include)
#include ms
include_directories(${TOP_DIR}/)
include_directories(${TOP_DIR}/mindspore/core/)
include_directories(${TOP_DIR}/mindspore/ccsrc/backend/kernel_compiler/cpu)
include_directories(${NNACL_DIR}/../)
include_directories(${LITE_DIR})
include_directories(${MICRO_DIR})
#include coder

View File

@ -15,8 +15,6 @@
*/
#include "coder/opcoders/nnacl/fp32/arithmetic_fp32_coder.h"
#include <string>
#include <map>
#include <type_traits>
#include "coder/opcoders/file_collector.h"
#include "nnacl/fp32/arithmetic_fp32.h"
#include "coder/opcoders/parallel.h"
@ -24,229 +22,76 @@
namespace mindspore::lite::micro::nnacl {
int ArithmeticFP32Coder::Init(CoderContext *const context) {
filter_tensor_ = input_tensors_.at(kWeightIndex);
MS_CHECK_PTR(filter_tensor_);
if (input_tensor_->data_type() == kNumberTypeFloat32 || input_tensor_->data_type() == kNumberTypeFloat16) {
data_type_ = kDataTypeFloat;
namespace {
std::string wrap_void(const std::string &a) { return "(void *)(" + a + ")"; }
std::string wrap_uint8(const std::string &a) { return "(uint8_t *)(" + a + ")"; }
std::string wrap_offset(const std::string &a, int offset) { return "(" + a + "+" + std::to_string(offset) + ")"; }
} // namespace
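These helpers only assemble C cast expressions as strings for the generated source. A minimal, self-contained sketch of the idea (the tensor name g_in0 below is made up for illustration, not part of the commit):

#include <iostream>
#include <string>

// Illustrative copies of the helpers above; they return C expressions as text.
std::string wrap_uint8(const std::string &a) { return "(uint8_t *)(" + a + ")"; }
std::string wrap_offset(const std::string &a, int offset) { return "(" + a + "+" + std::to_string(offset) + ")"; }

int main() {
  // Prints ((uint8_t *)(g_in0)+16) -- text that later lands verbatim in the generated .c file.
  std::cout << wrap_offset(wrap_uint8("g_in0"), 16) << std::endl;
  return 0;
}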
void ArithmeticFP32Coder::InitRunFunction(int primitive_type) {
ARITHMETIC_FUNC_INFO_FP32 fun_table[] = {
{PrimitiveType_MulFusion, schema::ActivationType_RELU, "ElementMulRelu", "ElementMulReluInt", "",
"ElementOptMulRelu", "ElementOptMulReluInt"},
{PrimitiveType_MulFusion, schema::ActivationType_RELU6, "ElementMulRelu6", "ElementMulRelu6Int", "",
"ElementOptMulRelu6", "ElementOptMulRelu6Int"},
{PrimitiveType_MulFusion, schema::ActivationType_NO_ACTIVATION, "ElementMul", "ElementMulInt", "", "ElementOptMul",
"ElementOptMulInt"},
{PrimitiveType_AddFusion, schema::ActivationType_RELU, "ElementAddRelu", "", "", "ElementOptAddRelu", ""},
{PrimitiveType_AddFusion, schema::ActivationType_RELU6, "ElementAddRelu6", "", "", "ElementOptAddRelu6", ""},
{PrimitiveType_AddFusion, schema::ActivationType_NO_ACTIVATION, "ElementAdd", "ElementAddInt", "", "ElementOptAdd",
"ElementOptAddInt"},
{PrimitiveType_SubFusion, schema::ActivationType_RELU, "ElementSubRelu", "", "", "ElementOptSubRelu", ""},
{PrimitiveType_SubFusion, schema::ActivationType_RELU6, "ElementSubRelu6", "", "", "ElementOptSubRelu6", ""},
{PrimitiveType_SubFusion, schema::ActivationType_NO_ACTIVATION, "ElementSub", "ElementSubInt", "", "ElementOptSub",
"ElementOptSubInt"},
{PrimitiveType_DivFusion, schema::ActivationType_RELU, "ElementDivRelu", "", "", "ElementOptDivRelu", ""},
{PrimitiveType_DivFusion, schema::ActivationType_RELU6, "ElementDivRelu6", "", "", "ElementOptDivRelu6", ""},
{PrimitiveType_DivFusion, schema::ActivationType_NO_ACTIVATION, "ElementDiv", "", "", "ElementOptDiv",
"ElementOptDivInt"},
{PrimitiveType_RealDiv, schema::ActivationType_RELU, "ElementDivRelu", "", "", "ElementOptDivRelu", ""},
{PrimitiveType_RealDiv, schema::ActivationType_RELU6, "ElementDivRelu6", "", "", "ElementOptDivRelu6", ""},
{PrimitiveType_RealDiv, schema::ActivationType_NO_ACTIVATION, "ElementDiv", "", "", "ElementOptDiv",
"ElementOptDivInt"},
{PrimitiveType_LogicalAnd, schema::ActivationType_NO_ACTIVATION, "ElementLogicalAnd", "ElementLogicalAndInt",
"ElementLogicalAndBool", "", ""},
{PrimitiveType_LogicalOr, schema::ActivationType_NO_ACTIVATION, "ElementLogicalOr", "", "ElementLogicalOrBool", "",
""},
{PrimitiveType_Maximum, schema::ActivationType_NO_ACTIVATION, "ElementMaximum", "ElementMaximumInt", "", "", ""},
{PrimitiveType_Minimum, schema::ActivationType_NO_ACTIVATION, "ElementMinimum", "ElementMinimumInt", "", "", ""},
{PrimitiveType_FloorMod, schema::ActivationType_NO_ACTIVATION, "ElementFloorMod", "ElementFloorModInt", "", "", ""},
{PrimitiveType_FloorDiv, schema::ActivationType_NO_ACTIVATION, "ElementFloorDiv", "ElementFloorDivInt", "", "", ""},
{PrimitiveType_Mod, schema::ActivationType_NO_ACTIVATION, "ElementMod", "ElementModInt", "", "ElementOptMod",
"ElementOptModInt"},
{PrimitiveType_SquaredDifference, schema::ActivationType_NO_ACTIVATION, "ElementSquaredDifference", "", "", "",
""}};
size_t length = sizeof(fun_table) / sizeof(ARITHMETIC_FUNC_INFO_FP32);
for (size_t i = 0; i < length; i++) {
if (fun_table[i].primitive_type_ == primitive_type &&
fun_table[i].activation_type_ == arithmetic_parameter_->activation_type_) {
arithmetic_run_ = fun_table[i].func_;
arithmetic_run_int_ = fun_table[i].int_func_;
arithmetic_run_bool_ = fun_table[i].bool_func_;
arithmetic_opt_run_ = fun_table[i].opt_func_;
arithmetic_opt_run_int_ = fun_table[i].opt_int_func_;
return;
}
}
TypeId input_type_id = input_tensor_->data_type();
data_type_len_ = lite::DataTypeSize(input_tensor_->data_type());
if (input_type_id == kNumberTypeFloat32 || input_type_id == kNumberTypeFloat) {
arithmetic_func_type_ = kArithmeticFuncFloat;
} else if (input_type_id == kNumberTypeBool) {
arithmetic_func_type_ = kArithmeticFuncBool;
} else if (input_type_id == kNumberTypeInt || input_type_id == kNumberTypeInt32) {
arithmetic_func_type_ = kArithmeticFuncInt;
} else {
data_type_ = kDataTypeInt;
arithmetic_func_type_ = kArithmeticFuncUnknow;
}
arithmetic_parameter_->in_elements_num0_ = input_tensor_->ElementsNum();
arithmetic_parameter_->in_elements_num1_ = filter_tensor_->ElementsNum();
arithmetic_parameter_->out_elements_num_ = output_tensor_->ElementsNum();
for (size_t i = 0; i < input_tensor_->shape().size(); i++) {
if (arithmetic_parameter_->in_shape0_[i] == -1) {
MS_CHECK_RET_CODE(
memcpy_s(arithmetic_parameter_->in_shape0_, DEFAULT_ARITHMETIC_NDIMS * sizeof(int),
static_cast<void *>(input_tensor_->shape().data()), input_tensor_->shape().size() * sizeof(int)),
"memcpy_s in shape0 failed!");
}
}
for (size_t i = 0; i < filter_tensor_->shape().size(); i++) {
if (arithmetic_parameter_->in_shape1_[i] == -1) {
MS_CHECK_RET_CODE(
memcpy_s(arithmetic_parameter_->in_shape1_, DEFAULT_ARITHMETIC_NDIMS * sizeof(int),
static_cast<void *>(filter_tensor_->shape().data()), filter_tensor_->shape().size() * sizeof(int)),
"memcpy_s in shape1 failed!");
}
}
for (size_t i = 0; i < output_tensor_->shape().size(); i++) {
if (arithmetic_parameter_->out_shape_[i] == -1) {
MS_CHECK_RET_CODE(
memcpy_s(arithmetic_parameter_->out_shape_, DEFAULT_ARITHMETIC_NDIMS * sizeof(int),
static_cast<void *>(output_tensor_->shape().data()), output_tensor_->shape().size() * sizeof(int)),
"memcpy_s in out shape failed!");
}
}
if (arithmetic_parameter_->in_elements_num0_ == 1 || arithmetic_parameter_->in_elements_num1_ == 1) {
switch (arithmetic_parameter_->op_parameter_.type_) {
case PrimitiveType_MulFusion:
switch (arithmetic_parameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_parameter_->broadcasting_ = false;
arithmetic_opt_run_ = "ElementOptMulRelu";
arithmetic_opt_run_int_ = "ElementOptMulReluInt";
break;
case schema::ActivationType_RELU6:
arithmetic_parameter_->broadcasting_ = false;
arithmetic_opt_run_ = "ElementOptMulRelu6";
arithmetic_opt_run_int_ = "ElementOptMulRelu6Int";
break;
default:
arithmetic_parameter_->broadcasting_ = false;
arithmetic_opt_run_ = "ElementOptMul";
arithmetic_opt_run_int_ = "ElementOptMulInt";
break;
}
break;
case PrimitiveType_AddFusion:
switch (arithmetic_parameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_parameter_->broadcasting_ = false;
arithmetic_opt_run_ = "ElementOptAddRelu";
arithmetic_opt_run_int_ = "ElementOptAddReluInt";
break;
case schema::ActivationType_RELU6:
arithmetic_parameter_->broadcasting_ = false;
arithmetic_opt_run_ = "ElementOptAddRelu6";
arithmetic_opt_run_int_ = "ElementOptAddRelu6Int";
break;
default:
arithmetic_parameter_->broadcasting_ = false;
arithmetic_opt_run_ = "ElementOptAdd";
arithmetic_opt_run_int_ = "ElementOptAddInt";
break;
}
break;
case PrimitiveType_SubFusion:
switch (arithmetic_parameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_parameter_->broadcasting_ = false;
arithmetic_opt_run_ = "ElementOptSubRelu";
break;
case schema::ActivationType_RELU6:
arithmetic_parameter_->broadcasting_ = false;
arithmetic_opt_run_ = "ElementOptSubRelu6";
break;
default:
arithmetic_parameter_->broadcasting_ = false;
arithmetic_opt_run_ = "ElementOptSub";
break;
}
break;
default:
break;
}
}
return RET_OK;
}
int ArithmeticFP32Coder::BroadcastRun(const std::string &input0, const std::string &input1, const std::string &output,
int dim, int out_count, int out_thread_stride, NNaclFp32Serializer *const code) {
if (dim > break_pos_) {
if (data_type_ == kDataTypeInt) {
*code << "\t\t" << arithmetic_run_int_ << "(((" << input0 << ") + " << out_thread_stride << "), ((" << input1
<< ") + " << out_thread_stride << "), ((" << output << ") + " << out_thread_stride << "), " << out_count
<< ");\n";
} else {
*code << "\t\t" << arithmetic_run_ << "(((" << input0 << ") + " << out_thread_stride << "), ((" << input1
<< ") + " << out_thread_stride << "), ((" << output << ") + " << out_thread_stride << "), " << out_count
<< ");\n";
}
return RET_OK;
}
for (int i = 0; i < arithmetic_parameter_->out_shape_[dim]; ++i) {
int pos0_ = arithmetic_parameter_->in_shape0_[dim] == 1 ? 0 : i;
int pos1_ = arithmetic_parameter_->in_shape1_[dim] == 1 ? 0 : i;
int error_code = BroadcastRun(input0 + "+" + std::to_string(pos0_ * arithmetic_parameter_->in_strides0_[dim]),
input1 + "+" + std::to_string(pos1_ * arithmetic_parameter_->in_strides1_[dim]),
output + "+" + std::to_string(i * arithmetic_parameter_->out_strides_[dim]), dim + 1,
out_count, out_thread_stride, code);
if (error_code != RET_OK) {
return error_code;
}
}
return RET_OK;
}
int ArithmeticFP32Coder::Prepare(CoderContext *const context) {
if (parameter_ == nullptr) {
return RET_ERROR;
}
arithmetic_parameter_ = reinterpret_cast<ArithmeticParameter *>(parameter_);
std::map<int, std::function<void()>> type_setters = {
{PrimitiveType_MulFusion,
[this]() {
switch (arithmetic_parameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_run_ = "ElementMulRelu";
arithmetic_run_int_ = "ElementMulReluInt";
break;
case schema::ActivationType_RELU6:
arithmetic_run_ = "ElementMulRelu6";
arithmetic_run_int_ = "ElementMulRelu6Int";
break;
default:
arithmetic_run_ = "ElementMul";
arithmetic_run_int_ = "ElementMulInt";
break;
}
}},
{PrimitiveType_AddFusion,
[this]() {
switch (arithmetic_parameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_run_ = "ElementAddRelu";
arithmetic_run_int_ = "ElementAddReluInt";
break;
case schema::ActivationType_RELU6:
arithmetic_run_ = "ElementAddRelu6";
arithmetic_run_int_ = "ElementAddRelu6Int";
break;
default:
arithmetic_run_ = "ElementAdd";
arithmetic_run_int_ = "ElementAddInt";
break;
}
}},
{PrimitiveType_SubFusion,
[this]() {
switch (arithmetic_parameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_run_ = "ElementSubRelu";
break;
case schema::ActivationType_RELU6:
arithmetic_run_ = "ElementSubRelu6";
break;
default:
arithmetic_run_ = "ElementSub";
break;
}
}},
{PrimitiveType_DivFusion,
[this]() {
switch (arithmetic_parameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_run_ = "ElementDivRelu";
break;
case schema::ActivationType_RELU6:
arithmetic_run_ = "ElementDivRelu6";
break;
default:
arithmetic_run_ = "ElementDiv";
break;
}
}},
{PrimitiveType_LogicalAnd, [this]() { arithmetic_run_ = "ElementLogicalAnd"; }},
{PrimitiveType_LogicalOr, [this]() { arithmetic_run_ = "ElementLogicalOr"; }},
{PrimitiveType_Maximum, [this]() { arithmetic_run_ = "ElementMaximum"; }},
{PrimitiveType_Minimum, [this]() { arithmetic_run_ = "ElementMinimum"; }},
{PrimitiveType_FloorDiv, [this]() { arithmetic_run_ = "ElementFloorDiv"; }},
{PrimitiveType_FloorMod, [this]() { arithmetic_run_ = "ElementFloorMod"; }},
{PrimitiveType_Equal, [this]() { arithmetic_run_ = "ElementEqual"; }},
{PrimitiveType_NotEqual, [this]() { arithmetic_run_ = "ElementNotEqual"; }},
{PrimitiveType_Less, [this]() { arithmetic_run_ = "ElementLess"; }},
{PrimitiveType_LessEqual, [this]() { arithmetic_run_ = "ElementLessEqual"; }},
{PrimitiveType_Greater, [this]() { arithmetic_run_ = "ElementGreater"; }},
{PrimitiveType_GreaterEqual, [this]() { arithmetic_run_ = "ElementGreaterEqual"; }},
{PrimitiveType_SquaredDifference, [this]() { arithmetic_run_ = "ElementSquaredDifference"; }},
};
auto iter = type_setters.find(parameter_->type_);
if (iter != type_setters.end()) {
iter->second();
} else {
MS_LOG(ERROR) << "Error Operator type " << parameter_;
arithmetic_run_ = "NULL";
return RET_ERROR;
}
MS_CHECK_RET_CODE(Init(context), "do arithmetic code failed!");
return RET_OK;
}
void ArithmeticFP32Coder::ComputeInOutStrides() {
int ArithmeticFP32Coder::ReSize(CoderContext *const context) {
CalcMultiplesAndStrides(arithmetic_parameter_);
if (arithmetic_parameter_->broadcasting_) {
outside_ = 1;
for (auto i = arithmetic_parameter_->ndim_ - 1; i >= 0; --i) {
@ -256,20 +101,185 @@ void ArithmeticFP32Coder::ComputeInOutStrides() {
}
outside_ *= arithmetic_parameter_->out_shape_[i];
}
ComputeStrides(arithmetic_parameter_->in_shape0_, arithmetic_parameter_->in_strides0_,
arithmetic_parameter_->ndim_);
ComputeStrides(arithmetic_parameter_->in_shape1_, arithmetic_parameter_->in_strides1_,
arithmetic_parameter_->ndim_);
ComputeStrides(arithmetic_parameter_->out_shape_, arithmetic_parameter_->out_strides_,
arithmetic_parameter_->ndim_);
}
int ret = RET_OK;
if (!IsScalarClac() && !IsBatchScalarCalc() && !IsBiasCalc()) {
ret = ConstTensorBroadCast(context);
}
return ret;
}
int ArithmeticFP32Coder::CheckDataType() {
auto in0_dataType = input_tensor_->data_type();
auto in1_dataType = filter_tensor_->data_type();
if (in0_dataType != in1_dataType) {
MS_LOG(ERROR) << "The dataTypes of input tensor0 and input tensor1 should be the same.";
return RET_ERROR;
}
return RET_OK;
}
void ArithmeticFP32Coder::ChooseArithmeticFunc(bool is_opt) {
if (input_tensor_->data_type() == kNumberTypeFloat32) {
if (is_opt) {
arithmetic_func_str_ = wrap_void(arithmetic_opt_run_);
} else {
arithmetic_func_str_ = wrap_void(arithmetic_run_);
}
} else if (input_tensor_->data_type() == kNumberTypeBool) {
arithmetic_func_str_ = wrap_void(arithmetic_run_bool_);
} else {
if (is_opt) {
arithmetic_func_str_ = wrap_void(arithmetic_opt_run_int_);
} else {
arithmetic_func_str_ = wrap_void(arithmetic_run_int_);
}
}
}
bool ArithmeticFP32Coder::IsScalarClac() {
return (arithmetic_parameter_->in_elements_num0_ == 1 || arithmetic_parameter_->in_elements_num1_ == 1) &&
(!arithmetic_opt_run_.empty());
}
bool ArithmeticFP32Coder::IsBatchScalarCalc() {
if (arithmetic_opt_run_.empty()) {
return false;
}
size_t break_axis = 0;
for (size_t i = 0; i < arithmetic_parameter_->ndim_; i++) {
if (arithmetic_parameter_->in_shape0_[i] != arithmetic_parameter_->in_shape1_[i]) {
break_axis = i;
break;
}
}
if (break_axis < arithmetic_parameter_->ndim_) {
for (size_t i = break_axis; i < arithmetic_parameter_->ndim_; i++) {
if (arithmetic_parameter_->in_shape1_[i] != 1) {
return false;
}
}
}
break_pos_ = break_axis;
return true;
}
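As a concrete illustration of the batch-scalar condition (shapes are made up): input0 {2, 3, 4} and input1 {2, 1, 1} first differ at axis 1, and every axis of input1 from there on is 1, so the pair qualifies and break_pos_ becomes 1. A standalone sketch of the same shape test, assuming an opt kernel is available:

#include <cstdio>

// Standalone sketch of the shape test in IsBatchScalarCalc above.
static bool IsBatchScalarShapes(const int *s0, const int *s1, int ndim, int *break_pos) {
  int break_axis = 0;
  for (int i = 0; i < ndim; ++i) {
    if (s0[i] != s1[i]) {
      break_axis = i;  // first axis where the shapes differ
      break;
    }
  }
  for (int i = break_axis; i < ndim; ++i) {
    if (s1[i] != 1) return false;  // from the break axis on, input1 must be all ones
  }
  *break_pos = break_axis;
  return true;
}

int main() {
  int s0[] = {2, 3, 4}, s1[] = {2, 1, 1}, pos = 0;
  printf("%d %d\n", IsBatchScalarShapes(s0, s1, 3, &pos), pos);  // prints "1 1"
  return 0;
}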
bool ArithmeticFP32Coder::IsBiasCalc() {
int last_shape0 = arithmetic_parameter_->in_shape0_[arithmetic_parameter_->ndim_ - 1];
int last_shape1 = arithmetic_parameter_->in_shape1_[arithmetic_parameter_->ndim_ - 1];
if (arithmetic_parameter_->in_elements_num0_ > arithmetic_parameter_->in_elements_num1_) {
return arithmetic_parameter_->in_elements_num1_ == last_shape1 && last_shape0 == last_shape1;
} else if (arithmetic_parameter_->in_elements_num0_ < arithmetic_parameter_->in_elements_num1_) {
return arithmetic_parameter_->in_elements_num0_ == last_shape0 && last_shape0 == last_shape1;
}
return false;
}
void ArithmeticFP32Coder::FreeConstTileBuff() {
if (input0_broadcast_ && input0_ptr_ != nullptr) {
input0_ptr_ = nullptr;
input0_broadcast_ = false;
}
if (input1_broadcast_ && input1_ptr_ != nullptr) {
input1_ptr_ = nullptr;
input1_broadcast_ = false;
}
}
int ArithmeticFP32Coder::ConstTensorBroadCast(CoderContext *const context) {
// if a const input needs broadcasting and every input that needs broadcasting is const, broadcast it at resize time
if (!arithmetic_parameter_->broadcasting_) {
return RET_OK;
}
if (output_tensor_->Size() < 0) {
return RET_OK;
}
// need broadcast both input
if (arithmetic_parameter_->in_elements_num0_ != arithmetic_parameter_->out_elements_num_ &&
arithmetic_parameter_->in_elements_num1_ != arithmetic_parameter_->out_elements_num_) {
return RET_OK;
}
FreeConstTileBuff();
NNaclFp32Serializer init_code;
Collect(context,
{
"wrapper/fp32/arithmetic_fp32_wrapper.h",
},
{
"arithmetic_fp32_wrapper.c",
});
if (input_tensor_->IsConst() &&
arithmetic_parameter_->in_elements_num0_ != arithmetic_parameter_->out_elements_num_) {
input0_ptr_ = reinterpret_cast<float *>(
allocator_->Malloc(kNumberTypeFloat32, arithmetic_parameter_->out_elements_num_ * data_type_len_, kWorkspace));
MS_CHECK_PTR(input0_ptr_);
init_code.CodeArray("in_shape", arithmetic_parameter_->in_shape0_, arithmetic_parameter_->ndim_, true);
init_code.CodeArray("in_stride", arithmetic_parameter_->in_strides0_, arithmetic_parameter_->ndim_, true);
init_code.CodeArray("out_stride", arithmetic_parameter_->out_strides_, arithmetic_parameter_->ndim_, true);
init_code.CodeArray("multiple", arithmetic_parameter_->multiples0_, arithmetic_parameter_->ndim_, true);
init_code.CodeFunction("TileConstTensor", input_tensor_, input0_ptr_, arithmetic_parameter_->ndim_, "in_shape",
"in_stride", "out_stride", "multiple");
input0_broadcast_ = true;
arithmetic_parameter_->in_elements_num0_ = arithmetic_parameter_->out_elements_num_;
arithmetic_parameter_->broadcasting_ = false;
}
if (filter_tensor_->IsConst() &&
arithmetic_parameter_->in_elements_num1_ != arithmetic_parameter_->out_elements_num_) {
input1_ptr_ = reinterpret_cast<float *>(
allocator_->Malloc(kNumberTypeFloat32, arithmetic_parameter_->out_elements_num_ * data_type_len_, kWorkspace));
MS_CHECK_PTR(input1_ptr_);
init_code.CodeArray("in_shape", arithmetic_parameter_->in_shape1_, arithmetic_parameter_->ndim_, true);
init_code.CodeArray("in_stride", arithmetic_parameter_->in_strides1_, arithmetic_parameter_->ndim_, true);
init_code.CodeArray("out_stride", arithmetic_parameter_->out_strides_, arithmetic_parameter_->ndim_, true);
init_code.CodeArray("multiple", arithmetic_parameter_->multiples1_, arithmetic_parameter_->ndim_, true);
init_code.CodeFunction("TileConstTensor", filter_tensor_, input1_ptr_, arithmetic_parameter_->ndim_, "in_shape",
"in_stride", "out_stride", "multiple");
input1_broadcast_ = true;
arithmetic_parameter_->in_elements_num1_ = arithmetic_parameter_->out_elements_num_;
arithmetic_parameter_->broadcasting_ = false;
}
return RET_OK;
}
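The generated TileConstTensor call above simply repeats the const input up to the output shape at init time. A minimal sketch of the tiling effect itself, with a made-up {1, 3} -> {2, 3} broadcast (not the nnacl implementation):

#include <cstdio>

// Repeat a {1, 3} const input along axis 0 to fill a {2, 3} output.
static void TileRowTo2x3(const float *in, float *out) {
  for (int r = 0; r < 2; ++r) {
    for (int c = 0; c < 3; ++c) {
      out[r * 3 + c] = in[c];  // axis 0 has size 1 in the input, so row 0 is reused
    }
  }
}

int main() {
  const float in[3] = {1.0f, 2.0f, 3.0f};
  float out[6] = {0.0f};
  TileRowTo2x3(in, out);
  for (int i = 0; i < 6; ++i) printf("%g ", out[i]);  // 1 2 3 1 2 3
  printf("\n");
  return 0;
}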
int ArithmeticFP32Coder::Prepare(CoderContext *const context) {
filter_tensor_ = input_tensors_.at(kWeightIndex);
MS_CHECK_PTR(filter_tensor_);
MS_CHECK_RET_CODE(CheckDataType(), "ArithmeticFP32Coder check data type failed");
MS_CHECK_PTR(parameter_);
arithmetic_parameter_ = reinterpret_cast<ArithmeticParameter *>(parameter_);
auto primitive_type = arithmetic_parameter_->op_parameter_.type_;
if (primitive_type == schema::PrimitiveType_Eltwise) {
switch (arithmetic_parameter_->eltwise_mode_) {
case schema::EltwiseMode_PROD:
primitive_type = schema::PrimitiveType_MulFusion;
break;
case schema::EltwiseMode_SUM:
primitive_type = schema::PrimitiveType_AddFusion;
break;
case schema::EltwiseMode_MAXIMUM:
primitive_type = schema::PrimitiveType_Maximum;
break;
default:
MS_LOG(ERROR) << "Eltwise mode not support, mode:" << arithmetic_parameter_->eltwise_mode_;
return RET_ERROR;
}
}
InitRunFunction(primitive_type);
MS_CHECK_RET_CODE(ReSize(context), "do arithmetic ReSize failed!");
return RET_OK;
}
void ArithmeticFP32Coder::CollectFilesForFunc(CoderContext *const context) {
/**
* nnacl combines all arithmetic operators into a single nnacl/arithmetic.c;
* that is not suitable for micro because of the generated package size,
* so only the files needed by the selected kernels are collected here.
*/
// collect wrapper files
Collect(context,
{
"wrapper/fp32/arithmetic_fp32_wrapper.h",
},
{
"arithmetic_fp32_wrapper.c",
});
// nnacl combines all arithmetic operators into nnacl/arithmetic.c;
// that is not suitable for micro because of the package size.
if (arithmetic_opt_run_ == "ElementOptSub" || arithmetic_run_ == "ElementSub") {
Collect(context,
{
@ -324,57 +334,148 @@ void ArithmeticFP32Coder::CollectFilesForFunc(CoderContext *const context) {
}
}
int ArithmeticFP32Coder::ExecuteCode(const std::string &input0, const std::string &input1, const std::string &output,
int size, bool is_opt, CoderContext *const context,
NNaclFp32Serializer *const code) {
if (arithmetic_func_str_.empty()) {
return RET_ERROR;
}
code->CodeStruct("arithmetic_parameter", *arithmetic_parameter_);
code->CodeFunction("ArithmeticExecute", input0, input1, output, size, is_opt, arithmetic_func_type_,
arithmetic_func_str_, "&arithmetic_parameter");
context->AppendCode(code->str());
return RET_OK;
}
int ArithmeticFP32Coder::BatchScalarCalc(int task_id, CoderContext *const context, NNaclFp32Serializer *const code) {
if (break_pos_ < 1) {
return RET_ERROR;
}
int batch = arithmetic_parameter_->out_elements_num_ / arithmetic_parameter_->out_strides_[break_pos_ - 1];
int batch_per_thread = UP_DIV(batch, thread_num_);
int start_batch = batch_per_thread * task_id;
int end_batch = MSMIN(start_batch + batch_per_thread, batch);
int batch_size = end_batch - start_batch;
int stride0 = arithmetic_parameter_->in_strides0_[break_pos_ - 1] * data_type_len_;
int stride1 = arithmetic_parameter_->in_strides1_[break_pos_ - 1] * data_type_len_;
int out_stride = arithmetic_parameter_->out_strides_[break_pos_ - 1] * data_type_len_;
int offset0 = stride0 * start_batch;
int offset1 = stride1 * start_batch;
int out_offset = out_stride * start_batch;
arithmetic_wrapper_info_ = {offset0, stride0, offset1, stride1, out_offset, out_stride, arithmetic_func_type_};
code->CodeStruct("arithmetic_wrapper_info", arithmetic_wrapper_info_);
code->CodeStruct("arithmetic_parameter", *arithmetic_parameter_);
code->CodeFunction("BatchScalarCalc", wrap_uint8(input0_ptr_str_), wrap_uint8(input1_ptr_str_),
wrap_uint8(output_ptr_str_), batch_size, arithmetic_parameter_->out_strides_[break_pos_ - 1], true,
arithmetic_func_str_, "&arithmetic_wrapper_info", "&arithmetic_parameter");
context->AppendCode(code->str());
return RET_OK;
}
int ArithmeticFP32Coder::BiasCalc(int task_id, CoderContext *const context, NNaclFp32Serializer *const code) {
int last_shape = arithmetic_parameter_->out_shape_[arithmetic_parameter_->ndim_ - 1];
int batch = arithmetic_parameter_->out_elements_num_ / last_shape;
int batch_per_thread = UP_DIV(batch, thread_num_);
int start_batch = batch_per_thread * task_id;
int end_batch = MSMIN(start_batch + batch_per_thread, batch);
int batch_size = end_batch - start_batch;
int stride = last_shape * data_type_len_;
int offset = stride * start_batch;
code->CodeStruct("arithmetic_parameter", *arithmetic_parameter_);
if (arithmetic_parameter_->in_elements_num0_ > arithmetic_parameter_->in_elements_num1_) {
arithmetic_wrapper_info_ = {offset, stride, 0, 0, offset, stride, arithmetic_func_type_};
code->CodeStruct("arithmetic_wrapper_info", arithmetic_wrapper_info_);
code->CodeFunction("BatchScalarCalc", wrap_uint8(input0_ptr_str_), wrap_uint8(input1_ptr_str_),
wrap_uint8(output_ptr_str_), batch_size, last_shape, false, arithmetic_func_str_,
"&arithmetic_wrapper_info", "&arithmetic_parameter");
} else {
arithmetic_wrapper_info_ = {0, 0, offset, stride, offset, stride, arithmetic_func_type_};
code->CodeStruct("arithmetic_wrapper_info", arithmetic_wrapper_info_);
code->CodeFunction("BatchScalarCalc", wrap_uint8(input0_ptr_str_), wrap_uint8(input1_ptr_str_),
wrap_uint8(output_ptr_str_), batch_size, last_shape, false, arithmetic_func_str_,
"&arithmetic_wrapper_info", "&arithmetic_parameter");
}
context->AppendCode(code->str());
return RET_OK;
}
int ArithmeticFP32Coder::BroadcastRun(const std::string &input0, const std::string &input1, const std::string &output,
int dim, int out_count, int out_thread_stride, CoderContext *const context,
NNaclFp32Serializer *const code) {
code->CodeStruct("arithmetic_parameter", *arithmetic_parameter_);
code->CodeFunction("BroadcastRun", wrap_uint8(input0_ptr_str_), wrap_uint8(input1_ptr_str_),
wrap_uint8(output_ptr_str_), dim, out_count, out_thread_stride, break_pos_, data_type_len_,
arithmetic_func_type_, arithmetic_func_str_, "&arithmetic_parameter");
context->AppendCode(code->str());
return RET_OK;
}
int ArithmeticFP32Coder::DoCode(CoderContext *const context) {
ComputeInOutStrides();
int element_num = output_tensor_->ElementsNum();
MS_CHECK_TRUE(thread_num_ > 0, "thread_num_ must be greater than zero");
int stride = UP_DIV(element_num, thread_num_);
int count = MSMIN(stride, element_num - stride * kDefaultTaskId);
MS_CHECK_TRUE(!arithmetic_run_.empty(), "arithmetic_run function is nullptr!");
if (count <= 0) {
return RET_OK;
}
int offset = stride * kDefaultTaskId * data_type_len_;
input0_ptr_str_ = allocator_->GetRuntimeAddr(input0_ptr_);
input1_ptr_str_ = allocator_->GetRuntimeAddr(input1_ptr_);
if (!input0_broadcast_) {
input0_ptr_str_ = allocator_->GetRuntimeAddr(input_tensor_, true);
input1_ptr_str_ = allocator_->GetRuntimeAddr(filter_tensor_);
}
if (!input1_broadcast_) {
input0_ptr_str_ = allocator_->GetRuntimeAddr(input_tensor_);
input1_ptr_str_ = allocator_->GetRuntimeAddr(filter_tensor_, true);
}
output_ptr_str_ = allocator_->GetRuntimeAddr(output_tensor_);
NNaclFp32Serializer code;
CollectFilesForFunc(context);
if (arithmetic_parameter_->broadcasting_) {
stride = UP_DIV(outside_, thread_num_);
out_count_ = MSMIN(stride, outside_ - stride * kDefaultTaskId);
out_thread_stride_ = stride * kDefaultTaskId;
std::string input0_str = allocator_->GetRuntimeAddr(input_tensor_);
std::string input1_str = allocator_->GetRuntimeAddr(filter_tensor_);
std::string output_str = allocator_->GetRuntimeAddr(output_tensor_);
MS_CHECK_RET_CODE(BroadcastRun(input0_str, input1_str, output_str, 0, out_count_, out_thread_stride_, &code),
"do broad cast code failed!");
} else if (!arithmetic_opt_run_.empty()) {
code.CodeStruct("arithmetic_parameter", *arithmetic_parameter_);
if (IsScalarClac()) {
ChooseArithmeticFunc(true);
if (arithmetic_parameter_->in_elements_num0_ == 1) {
if (data_type_ == kDataTypeFloat) {
code.CodeFunction(arithmetic_opt_run_, input_tensor_, filter_tensor_, output_tensor_, count,
"&arithmetic_parameter");
} else {
code.CodeFunction(arithmetic_opt_run_int_, input_tensor_, filter_tensor_, output_tensor_, count,
"&arithmetic_parameter");
}
return ExecuteCode(wrap_uint8(input0_ptr_str_), wrap_offset(wrap_uint8(input1_ptr_str_), offset),
wrap_offset(wrap_uint8(output_ptr_str_), offset), count, true, context, &code);
} else if (arithmetic_parameter_->in_elements_num1_ == 1) {
if (data_type_ == kDataTypeFloat) {
code.CodeFunction(arithmetic_opt_run_, input_tensor_, filter_tensor_, output_tensor_, count,
"&arithmetic_parameter");
} else {
code.CodeFunction(arithmetic_opt_run_int_, input_tensor_, filter_tensor_, output_tensor_, count,
"&arithmetic_parameter");
}
} else {
MS_LOG(ERROR) << "arithmetic opt code run: at least one of inputs is scalar";
return RET_ERROR;
}
} else {
if (data_type_ == kDataTypeFloat) {
code.CodeFunction(arithmetic_run_, input_tensor_, filter_tensor_, output_tensor_, count);
} else {
code.CodeFunction(arithmetic_run_int_, input_tensor_, filter_tensor_, output_tensor_, count);
return ExecuteCode(wrap_offset(wrap_uint8(input0_ptr_str_), offset), wrap_void(input1_ptr_str_),
wrap_offset(wrap_uint8(output_ptr_str_), offset), count, true, context, &code);
}
}
MS_LOG(DEBUG) << "ArithmeticFP32Coder has been called";
context->AppendCode(code.str());
return RET_OK;
// run the opt function: in every batch one of the inputs is a scalar
if (IsBatchScalarCalc()) {
ChooseArithmeticFunc(true);
return BatchScalarCalc(kDefaultTaskId, context, &code);
}
// each batch is an element-wise calculation
if (IsBiasCalc()) {
ChooseArithmeticFunc(false);
return BiasCalc(kDefaultTaskId, context, &code);
}
// broadcasting is needed at runtime
if (arithmetic_parameter_->broadcasting_) {
ChooseArithmeticFunc(false);
stride = UP_DIV(outside_, thread_num_);
int out_count = MSMIN(stride, outside_ - stride * kDefaultTaskId);
if (out_count <= 0) {
return RET_OK;
}
return BroadcastRun(input0_ptr_str_, input1_ptr_str_, output_ptr_str_, 0, out_count, stride * kDefaultTaskId,
context, &code);
}
// element-wise calculation over all elements
ChooseArithmeticFunc(false);
return ExecuteCode(wrap_offset(wrap_uint8(input0_ptr_str_), offset), wrap_offset(wrap_uint8(input1_ptr_str_), offset),
wrap_offset(wrap_uint8(output_ptr_str_), offset), count, false, context, &code);
}
REG_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_AddFusion, CPUOpCoderCreator<ArithmeticFP32Coder>)

View File

@ -22,7 +22,7 @@
#include "coder/opcoders/op_coder.h"
#include "nnacl/fp32/arithmetic_fp32.h"
#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
#define DEFAULT_ARITHMETIC_NDIMS 10
#include "wrapper/fp32/arithmetic_fp32_wrapper.h"
namespace mindspore::lite::micro::nnacl {
using mindspore::schema::PrimitiveType_AddFusion;
@ -65,7 +65,19 @@ using mindspore::schema::PrimitiveType_Eltwise;
using mindspore::schema::PrimitiveType_Minimum;
using mindspore::schema::PrimitiveType_Mod;
class ArithmeticFP32Coder final : public OperatorCoder {
typedef struct {
int primitive_type_;
int activation_type_;
std::string func_;
std::string int_func_;
std::string bool_func_;
std::string opt_func_;
std::string opt_int_func_;
} ARITHMETIC_FUNC_INFO_FP32;
public:
ArithmeticFP32Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
const Model::Node *node, size_t node_index, Target target)
@ -78,15 +90,39 @@ class ArithmeticFP32Coder final : public OperatorCoder {
int DoCode(CoderContext *const context) override;
private:
int Init(CoderContext *const context);
int ReSize(CoderContext *const context);
int ExecuteCode(const std::string &input0, const std::string &input1, const std::string &output, int size,
bool is_opt, CoderContext *const context, NNaclFp32Serializer *const code);
void InitRunFunction(int primitive_type);
int CheckDataType();
void ChooseArithmeticFunc(bool is_opt);
bool IsScalarClac();
bool IsBatchScalarCalc();
bool IsBiasCalc();
void FreeConstTileBuff();
int ConstTensorBroadCast(CoderContext *const context);
void ComputeInOutStrides();
int BroadcastRun(const std::string &input0, const std::string &input1, const std::string &output, int dim,
int out_count, int out_thread_stride, NNaclFp32Serializer *const code);
int out_count, int out_thread_stride, CoderContext *const context, NNaclFp32Serializer *const code);
int BatchScalarCalc(int task_id, CoderContext *const context, NNaclFp32Serializer *const code);
int BiasCalc(int task_id, CoderContext *const context, NNaclFp32Serializer *const code);
void CollectFilesForFunc(CoderContext *const context);
private:
int break_pos_{0};
int outside_{0};
@ -95,10 +131,32 @@ class ArithmeticFP32Coder final : public OperatorCoder {
int out_count_{0};
int data_type_len_{0};
bool input0_broadcast_{false};
bool input1_broadcast_{false};
float *input0_ptr_{nullptr};
float *input1_ptr_{nullptr};
float *output_ptr_{nullptr};
ArithmeticParameter *arithmetic_parameter_{nullptr};
Tensor *filter_tensor_{nullptr};
ArithmeticFuncType arithmetic_func_type_{kArithmeticFuncUnknow};
ArithmeticWrapperInfo arithmetic_wrapper_info_{};
std::string input0_ptr_str_;
std::string input1_ptr_str_;
std::string output_ptr_str_;
std::string arithmetic_run_;
std::string arithmetic_run_int_;
@ -107,7 +165,9 @@ class ArithmeticFP32Coder final : public OperatorCoder {
std::string arithmetic_opt_run_int_;
LiteDataType data_type_{kDataTypeFloat};
std::string arithmetic_run_bool_;
std::string arithmetic_func_str_;
};
} // namespace mindspore::lite::micro::nnacl
#endif // MINDSPORE_LITE_MICRO_CODER_OPCODERS_FP32_ARITHMETIC_FP32_CODER_H_

View File

@ -28,7 +28,6 @@
using mindspore::schema::PrimitiveType_MatMul;
namespace mindspore::lite::micro::nnacl {
int MatMulFP32BaseCoder::ReSize() {
ResizeParameter();
thread_count_ = MSMIN(thread_num_, UP_DIV(params_->col_align_, col_tile_));
@ -45,8 +44,12 @@ int MatMulFP32BaseCoder::ReSize() {
int MatMulFP32BaseCoder::InitBiasData() {
if (input_tensors_.size() == 3) {
int max_bias_data = UP_ROUND(bias_tensor_->ElementsNum(), C16NUM);
int max_bias_data = params_->col_align_;
bias_pack_ptr_size_ = static_cast<size_t>(max_bias_data * sizeof(float));
if (bias_tensor_->ElementsNum() == 1) {
is_bias_broadcast_ = true;
}
ori_bias_pack_ptr_size_ = bias_tensor_->ElementsNum() * sizeof(float);
bias_ptr_ = reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeFloat32, kOnlineSize, kOnlinePackWeight));
MS_CHECK_PTR(bias_ptr_);
}
@ -123,8 +126,7 @@ int MatMulFP32BaseCoder::Init() {
int MatMulFP32BaseCoder::Prepare(CoderContext *const context) { return RET_OK; }
int MatMulFP32BaseCoder::DoCode(CoderContext *const context) {
// generate code .h .c
int MatMulFP32BaseCoder::CollectFilesForTarget(CoderContext *const context) {
Collect(context,
{
"nnacl/fp32/matmul_fp32.h",
@ -158,6 +160,11 @@ int MatMulFP32BaseCoder::DoCode(CoderContext *const context) {
"dequant_int8_to_fp32_wrapper.c",
});
}
return RET_OK;
}
int MatMulFP32BaseCoder::DoCode(CoderContext *const context) {
CollectFilesForTarget(context);
NNaclFp32Serializer code;
NNaclFp32Serializer init_code;
code.CodeStruct("mat_mul_parameter", *params_);
@ -165,7 +172,17 @@ int MatMulFP32BaseCoder::DoCode(CoderContext *const context) {
// do bias packing in the generated init code
if (bias_ptr_) {
init_code.CodeMallocExpression(bias_ptr_, bias_pack_ptr_size_);
init_code.CodeFunction("memcpy", bias_ptr_, bias_tensor_, bias_pack_ptr_size_);
init_code.CodeFunction("memset", bias_ptr_, 0, bias_pack_ptr_size_);
int max_bias_data = params_->col_align_;
if (is_bias_broadcast_) {
float broad_cast_data = (reinterpret_cast<float *>(bias_tensor_->data_c()))[0];
std::string bias_ptr_str = "((float *)(" + allocator_->GetRuntimeAddr(bias_ptr_) + "))";
init_code << "\tfor (int i = 0; i < " << max_bias_data << "; ++i) {\n";
init_code << "\t\t" << bias_ptr_str << "[i] = " << broad_cast_data << ";\n";
init_code << "\t}\n";
} else {
init_code.CodeFunction("memcpy", bias_ptr_, bias_tensor_, ori_bias_pack_ptr_size_);
}
}
// Get Tensor Pointer
@ -179,6 +196,7 @@ int MatMulFP32BaseCoder::DoCode(CoderContext *const context) {
if (!params_->a_const_) {
code.CodeFunction("InitMatrixA", input_tensor_, a_pack_ptr_, "&mat_mul_parameter", vec_matmul_);
init_code.CodeMallocExpression(b_pack_ptr_, b_pack_ptr_size_);
init_code.CodeFunction("memset", b_pack_ptr_, 0, b_pack_ptr_size_);
std::string b_src_str = b_str;
if (de_quant_flag_) {
// reuse to b_pack_str
@ -191,6 +209,7 @@ int MatMulFP32BaseCoder::DoCode(CoderContext *const context) {
}
if (!params_->b_const_) {
init_code.CodeMallocExpression(a_pack_str, a_pack_ptr_size_);
init_code.CodeFunction("memset", a_pack_ptr_, 0, a_pack_ptr_size_);
std::string a_src_str = a_str;
if (de_quant_flag_) {
// reuse to a_pack_str
@ -228,7 +247,6 @@ int MatMulFP32BaseCoder::DoCode(CoderContext *const context) {
params_->deep_, params_->row_, cur_oc, params_->col_, "OutType_Nhwc");
}
code << "\t\t}\n";
context->AppendCode(code.str());
context->AppendInitCode(init_code.str());
return RET_OK;
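For reference, the broadcast-bias branch above emits a fill loop into the init code instead of a memcpy; conceptually the generated code does the following at runtime (the col_align and bias values below are made up):

#include <cstdio>

int main() {
  const int col_align = 8;         // stand-in for params_->col_align_
  const float bias_scalar = 0.5f;  // stand-in for the single element of bias_tensor_
  float bias_pack[8];
  // Same effect as the emitted "for (int i = 0; i < col_align; ++i) bias[i] = bias_scalar;" loop.
  for (int i = 0; i < col_align; ++i) {
    bias_pack[i] = bias_scalar;
  }
  printf("%g %g\n", bias_pack[0], bias_pack[col_align - 1]);  // 0.5 0.5
  return 0;
}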

View File

@ -43,6 +43,7 @@ class MatMulFP32BaseCoder : public OperatorCoder {
int InitBufferB();
int InitMatrixA(const float *src_ptr);
int InitMatrixB(const float *src_ptr);
int CollectFilesForTarget(CoderContext *const context);
protected:
virtual int Init();
@ -64,8 +65,10 @@ class MatMulFP32BaseCoder : public OperatorCoder {
int thread_stride_{0};
int thread_count_{0};
size_t bias_pack_ptr_size_{0};
size_t ori_bias_pack_ptr_size_{0};
size_t a_pack_ptr_size_{0};
size_t b_pack_ptr_size_{0};
bool is_bias_broadcast_{false};
};
} // namespace mindspore::lite::micro::nnacl
#endif // MINDSPORE_LITE_MICRO_CODER_OPCODERS_NNACL_FP32_MATMUL_FP32_BASE_CODER_H_

View File

@ -74,13 +74,13 @@ void NNaclFp32Serializer::CodeStruct(const std::string &name, const ConvParamete
}
void NNaclFp32Serializer::CodeStruct(const std::string &name, const MatMulParameter &mat_mul_parameter) {
CodeBaseStruct("MatMulParameter", name, mat_mul_parameter.op_parameter_, mat_mul_parameter.has_bias_,
mat_mul_parameter.row_, mat_mul_parameter.col_, mat_mul_parameter.row_4_, mat_mul_parameter.row_6_,
mat_mul_parameter.row_12_, mat_mul_parameter.row_16_, mat_mul_parameter.row_align_,
mat_mul_parameter.col_4_, mat_mul_parameter.col_8_, mat_mul_parameter.col_align_,
mat_mul_parameter.deep_, mat_mul_parameter.deep_4_, mat_mul_parameter.deep_16_,
mat_mul_parameter.batch, mat_mul_parameter.a_transpose_, mat_mul_parameter.b_transpose_,
mat_mul_parameter.a_const_, mat_mul_parameter.b_const_, mat_mul_parameter.act_type_);
CodeBaseStruct(
"MatMulParameter", name, mat_mul_parameter.op_parameter_, mat_mul_parameter.has_bias_, mat_mul_parameter.row_,
mat_mul_parameter.col_, mat_mul_parameter.row_4_, mat_mul_parameter.row_6_, mat_mul_parameter.row_12_,
mat_mul_parameter.row_16_, mat_mul_parameter.row_align_, mat_mul_parameter.col_4_, mat_mul_parameter.col_8_,
mat_mul_parameter.col_align_, mat_mul_parameter.deep_, mat_mul_parameter.deep_4_, mat_mul_parameter.deep_16_,
mat_mul_parameter.batch, mat_mul_parameter.a_transpose_, mat_mul_parameter.b_transpose_, mat_mul_parameter.a_const_,
mat_mul_parameter.b_const_, mat_mul_parameter.act_type_, mat_mul_parameter.use_axis_, mat_mul_parameter.axis_);
}
void NNaclFp32Serializer::CodeStruct(const std::string &name, const ScaleParameter &scale_parameter) {
@ -143,4 +143,11 @@ void NNaclFp32Serializer::CodeStruct(const std::string &name, const StridedSlice
strided_slice_parameter.begins_mask_, strided_slice_parameter.ellipsisMask_,
strided_slice_parameter.newAxisMask_, strided_slice_parameter.shrinkAxisMask_);
}
void NNaclFp32Serializer::CodeStruct(const std::string &name, const ArithmeticWrapperInfo &arithmetic_wrapper_info) {
CodeBaseStruct("ArithmeticWrapperInfo", name, arithmetic_wrapper_info.offset0_, arithmetic_wrapper_info.stride0_,
arithmetic_wrapper_info.offset1_, arithmetic_wrapper_info.stride1_,
arithmetic_wrapper_info.out_offset_, arithmetic_wrapper_info.out_stride_,
arithmetic_wrapper_info.arithmetic_func_type_);
}
} // namespace mindspore::lite::micro::nnacl

View File

@ -34,6 +34,7 @@
#include "wrapper/fp32/dequant_int8_to_fp32_wrapper.h"
#include "nnacl/fp32/exp_fp32.h"
#include "nnacl/fp32/strided_slice_fp32.h"
#include "wrapper/fp32/arithmetic_fp32_wrapper.h"
namespace mindspore::lite::micro::nnacl {
class NNaclFp32Serializer : public Serializer {
@ -55,6 +56,7 @@ class NNaclFp32Serializer : public Serializer {
void CodeStruct(const std::string &name, const SpliceParameter &splice_parameter);
void CodeStruct(const std::string &name, const ExpParameter &exp_parameter);
void CodeStruct(const std::string &name, const StridedSliceParameter &strided_slice_parameter);
void CodeStruct(const std::string &name, const ArithmeticWrapperInfo &arithmetic_wrapper_info);
};
} // namespace mindspore::lite::micro::nnacl

View File

@ -0,0 +1,83 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "wrapper/fp32/arithmetic_fp32_wrapper.h"
void TileConstTensor(const float *in_data, float *out_data, size_t ndim, const int *in_shape, const int *in_strides,
const int *out_strides, const int *multiple) {
TileOneDimensionFp32(in_data, out_data, 0, ndim, in_shape, in_strides, out_strides, multiple);
}
void ArithmeticExecute(const void *input0, const void *input1, void *output, int size, bool is_opt,
ArithmeticFuncType arithmetic_func_type, const void *arithmetic_func,
const ArithmeticParameter *param) {
if (arithmetic_func_type == kArithmeticFuncFloat) {
if (is_opt) {
ArithmeticOptRun arithmetic_opt_run = (ArithmeticOptRun)(arithmetic_func);
arithmetic_opt_run((const float *)(input0), (const float *)(input1), (float *)(output), size, param);
} else {
ArithmeticRun arithmetic_run = (ArithmeticRun)(arithmetic_func);
arithmetic_run((const float *)(input0), (const float *)(input1), (float *)(output), size);
}
} else if (arithmetic_func_type == kArithmeticFuncBool) {
ArithmeticBoolRun arithmetic_run_bool = (ArithmeticBoolRun)(arithmetic_func);
arithmetic_run_bool((const bool *)(input0), (const bool *)(input1), (bool *)(output), size);
} else if (arithmetic_func_type == kArithmeticFuncInt) {
if (is_opt) {
ArithmeticOptIntRun arithmetic_opt_run_int = (ArithmeticOptIntRun)(arithmetic_func);
arithmetic_opt_run_int((const int *)(input0), (const int *)(input1), (int *)(output), size, param);
} else {
ArithmeticIntRun arithmetic_run_int = (ArithmeticIntRun)(arithmetic_func);
arithmetic_run_int((const int *)(input0), (const int *)(input1), (int *)(output), size);
}
}
}
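A minimal, self-contained illustration of the void-pointer dispatch used above; the ElementAdd stand-in is a stub, not the nnacl kernel, and the function-pointer/void* round trip is the same conditionally-supported cast the wrapper relies on:

#include <cstdio>

typedef int (*ArithmeticRun)(const float *input0, const float *input1, float *output, int element_size);

// Stub standing in for an nnacl element-wise kernel such as ElementAdd.
static int StubElementAdd(const float *in0, const float *in1, float *out, int size) {
  for (int i = 0; i < size; ++i) out[i] = in0[i] + in1[i];
  return 0;
}

// Same dispatch shape as ArithmeticExecute: the kernel arrives as const void * and is cast back.
static void Execute(const void *in0, const void *in1, void *out, int size, const void *func) {
  ArithmeticRun run = (ArithmeticRun)(func);
  run((const float *)in0, (const float *)in1, (float *)out, size);
}

int main() {
  float a[4] = {1, 2, 3, 4}, b[4] = {10, 20, 30, 40}, c[4] = {0};
  Execute(a, b, c, 4, (const void *)StubElementAdd);
  for (int i = 0; i < 4; ++i) printf("%g ", c[i]);  // 11 22 33 44
  printf("\n");
  return 0;
}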
void BatchScalarCalc(const void *input0, const void *input1, void *output, int batch_size, int size, bool is_opt,
const void *arithmetic_func, const ArithmeticWrapperInfo *wrapper_info,
const ArithmeticParameter *param) {
int offset0 = wrapper_info->offset0_;
int offset1 = wrapper_info->offset1_;
int out_offset = wrapper_info->out_offset_;
int stride0 = wrapper_info->stride0_;
int stride1 = wrapper_info->stride1_;
int out_stride = wrapper_info->out_stride_;
for (int i = 0; i < batch_size; i++) {
ArithmeticExecute((const uint8_t *)(input0) + offset0, (const uint8_t *)(input1) + offset1,
(uint8_t *)(output) + out_offset, size, is_opt, wrapper_info->arithmetic_func_type_,
arithmetic_func, param);
offset0 += stride0;
offset1 += stride1;
out_offset += out_stride;
}
}
void BroadcastRun(const void *input0, const void *input1, void *output, int dim, int out_count, int out_thread_stride,
int break_pos, int data_type_len, ArithmeticFuncType arithmetic_func_type,
const void *arithmetic_func, const ArithmeticParameter *param) {
if (dim > break_pos) {
int offset = out_thread_stride * data_type_len;
ArithmeticExecute((const uint8_t *)(input0) + offset, (const uint8_t *)(input1) + offset,
(uint8_t *)(output) + offset, out_count, false, arithmetic_func_type, arithmetic_func, param);
return;
}
int offset_size[] = {param->in_strides0_[dim] * data_type_len, param->in_strides1_[dim] * data_type_len,
param->out_strides_[dim] * data_type_len};
for (int i = 0; i < param->out_shape_[dim]; ++i) {
int pos0_ = param->in_shape0_[dim] == 1 ? 0 : i;
int pos1_ = param->in_shape1_[dim] == 1 ? 0 : i;
BroadcastRun((const uint8_t *)(input0) + pos0_ * offset_size[0], (const uint8_t *)(input1) + pos1_ * offset_size[1],
(uint8_t *)(output) + i * offset_size[2], dim + 1, out_count, out_thread_stride, break_pos,
data_type_len, arithmetic_func_type, arithmetic_func, param);
}
}
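The in_strides/out_strides driving this recursion are plain row-major strides; a tiny self-contained sketch with a made-up shape, mirroring the idea of nnacl's ComputeStrides:

#include <cstdio>

// Row-major strides for shape {2, 3, 4} are {12, 4, 1}.
static void ComputeStridesSketch(const int *shape, int *strides, int ndim) {
  int stride = 1;
  for (int i = ndim - 1; i >= 0; --i) {
    strides[i] = stride;
    stride *= shape[i];
  }
}

int main() {
  int shape[3] = {2, 3, 4}, strides[3] = {0};
  ComputeStridesSketch(shape, strides, 3);
  printf("%d %d %d\n", strides[0], strides[1], strides[2]);  // 12 4 1
  return 0;
}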

View File

@ -0,0 +1,66 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_MICRO_CODER_WRAPPER_FP32_ARITHMETIC_FP32_WRAPPER_H_
#define MINDSPORE_LITE_MICRO_CODER_WRAPPER_FP32_ARITHMETIC_FP32_WRAPPER_H_
#include "nnacl/fp32/arithmetic_fp32.h"
#include <string.h>
#ifdef __cplusplus
extern "C" {
#endif
typedef enum ArithmeticFuncType {
kArithmeticFuncFloat = 0,
kArithmeticFuncBool = 1,
kArithmeticFuncInt = 2,
kArithmeticFuncUnknow = 3,
} ArithmeticFuncType;
typedef struct ArithmeticWrapperInfo {
int offset0_;
int stride0_;
int offset1_;
int stride1_;
int out_offset_;
int out_stride_;
ArithmeticFuncType arithmetic_func_type_;
} ArithmeticWrapperInfo;
typedef int (*ArithmeticRun)(const float *input0, const float *input1, float *output, const int element_size);
typedef int (*ArithmeticOptRun)(const float *input0, const float *input1, float *output, const int element_size,
const ArithmeticParameter *param);
typedef int (*ArithmeticIntRun)(const int *input0, const int *input1, int *output, const int element_size);
typedef int (*ArithmeticOptIntRun)(const int *input0, const int *input1, int *output, const int element_size,
const ArithmeticParameter *param);
typedef int (*ArithmeticBoolRun)(const bool *input0, const bool *input1, bool *output, const int element_size);
void ArithmeticExecute(const void *input0, const void *input1, void *output, int size, bool is_opt,
ArithmeticFuncType arithmetic_func_type, const void *arithmetic_func,
const ArithmeticParameter *param);
void TileConstTensor(const float *in_data, float *out_data, size_t ndim, const int *in_shape, const int *in_strides,
const int *out_strides, const int *multiple);
void BatchScalarCalc(const void *input0, const void *input1, void *output, int batch_size, int size, bool is_opt,
const void *arithmetic_func, const ArithmeticWrapperInfo *wrapper_info,
const ArithmeticParameter *param);
void BroadcastRun(const void *input0, const void *input1, void *output, int dim, int out_count, int out_thread_stride,
int break_pos, int data_type_len, ArithmeticFuncType arithmetic_func_type,
const void *arithmetic_func, const ArithmeticParameter *param);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_MICRO_CODER_WRAPPER_FP32_ARITHMETIC_FP32_WRAPPER_H_

View File

@ -11,6 +11,7 @@ include(${TOP_DIR}/cmake/external_libs/gtest.cmake)
include(${MICRO_DIR}/cmake/file_list.cmake)
include(${MICRO_DIR}/cmake/package_wrapper.cmake)
include_directories(${NNACL_DIR}/../)
include_directories(${TOP_DIR})
include_directories(${TOP_DIR}/mindspore/core/)
include_directories(${LITE_DIR})

View File

@ -90,17 +90,52 @@ int MatmulFp32BaseCPUKernel::InitBufferB() {
return RET_OK;
}
int MatmulFp32BaseCPUKernel::CalBroadCastBiasDataElements() {
lite::Tensor *bias_tensor = in_tensors_.at(2);
int max_bias_data = UP_ROUND(bias_tensor->ElementsNum(), C16NUM);
if (!params_->b_const_) {
MS_LOG(WARNING) << "matmul does not support broadcast bias data";
} else {
lite::Tensor *const_tensor = in_tensors_.at(1);
size_t shape_size = const_tensor->shape().size();
if (shape_size < kBiasIndex) {
return max_bias_data;
}
if (params_->b_transpose_) {
max_bias_data = UP_ROUND(const_tensor->shape()[shape_size - kBiasIndex], C16NUM);
} else {
max_bias_data = UP_ROUND(const_tensor->shape()[shape_size - kWeightIndex], C16NUM);
}
}
return max_bias_data;
}
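UP_ROUND here rounds the bias length up to the C16NUM (16) tile; a one-line sketch of the macro's arithmetic with an example value:

#include <cstdio>

// Same arithmetic as nnacl's UP_ROUND(x, y): round x up to the next multiple of y.
static inline int UpRoundSketch(int x, int y) { return (x + y - 1) / y * y; }

int main() {
  printf("%d\n", UpRoundSketch(20, 16));  // 32
  return 0;
}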
int MatmulFp32BaseCPUKernel::InitBiasData() {
if (in_tensors_.size() == 3) {
auto bias_tensor = in_tensors_[2];
int max_bias_data = UP_ROUND(bias_tensor->ElementsNum(), C16NUM);
bias_ptr_ = reinterpret_cast<float *>(malloc(max_bias_data * sizeof(float)));
if (bias_ptr_ == nullptr) {
MS_LOG(ERROR) << "malloc bias_ptr_ failed";
return RET_ERROR;
// whether to broadcast bias data
if (bias_tensor->ElementsNum() == 1) {
max_bias_data = CalBroadCastBiasDataElements();
bias_ptr_ = reinterpret_cast<float *>(malloc(max_bias_data * sizeof(float)));
if (bias_ptr_ == nullptr) {
MS_LOG(ERROR) << "malloc bias_ptr_ failed";
return RET_ERROR;
}
float broadcast_data = (reinterpret_cast<float *>(bias_tensor->data_c()))[0];
// broadcast bias data
for (int i = 0; i < max_bias_data; ++i) {
bias_ptr_[i] = broadcast_data;
}
} else {
bias_ptr_ = reinterpret_cast<float *>(malloc(max_bias_data * sizeof(float)));
if (bias_ptr_ == nullptr) {
MS_LOG(ERROR) << "malloc bias_ptr_ failed";
return RET_ERROR;
}
memset(bias_ptr_, 0, max_bias_data * sizeof(float));
memcpy(bias_ptr_, bias_tensor->data_c(), bias_tensor->ElementsNum() * sizeof(float));
}
memset(bias_ptr_, 0, max_bias_data * sizeof(float));
memcpy(bias_ptr_, bias_tensor->data_c(), bias_tensor->ElementsNum() * sizeof(float));
}
return RET_OK;
}

View File

@ -56,6 +56,7 @@ class MatmulFp32BaseCPUKernel : public LiteKernel {
void ResizeParameter();
void FreeResizeBufA();
void FreeResizeBufB();
int CalBroadCastBiasDataElements();
protected:
MatMulParameter *params_ = nullptr;

View File

@ -6,7 +6,6 @@ set(CCSRC_SRC
${CCSRC_DIR}/backend/optimizer/common/visit.cc
${CCSRC_DIR}/backend/optimizer/common/optimizer.cc
)
set(NNACL_DIR ${CCSRC_DIR}/backend/kernel_compiler/cpu/nnacl)
include(${TOP_DIR}/cmake/external_libs/glog.cmake)
include_directories(${TOP_DIR}/mindspore/ccsrc/backend/kernel_compiler/cpu)