forked from mindspore-Ecosystem/mindspore
fix pdr_pdr_adf
This commit is contained in:
parent
99a5dacdc7
commit
dd51f4e67f
|
@ -109,6 +109,7 @@ endif()
|
|||
string(REPLACE "/mindspore/lite" "" TOP_DIR ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
set(CORE_DIR ${TOP_DIR}/mindspore/core)
|
||||
set(CCSRC_DIR ${TOP_DIR}/mindspore/ccsrc)
|
||||
set(NNACL_DIR ${CCSRC_DIR}/backend/kernel_compiler/cpu/nnacl)
|
||||
include_directories(${TOP_DIR})
|
||||
include_directories(${CORE_DIR})
|
||||
include_directories(${CORE_DIR}/ir)
|
||||
|
|
|
@ -6,7 +6,7 @@ option(PLATFORM_ARM32 "build operator for android arm 32" off)
|
|||
option(PLATFORM_ARM64 "build operator for android arm 64" off)
|
||||
|
||||
string(REPLACE "/mindspore/lite/micro" "" TOP_DIR ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
|
||||
set(NNACL_DIR ${TOP_DIR}/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl)
|
||||
include_directories(${CMAKE_BINARY_DIR})
|
||||
include(${TOP_DIR}/cmake/utils.cmake)
|
||||
include(${TOP_DIR}/cmake/dependency_utils.cmake)
|
||||
|
|
|
@ -177,7 +177,6 @@ set(LITE_SRC
|
|||
${LITE_DIR}/tools/common/flag_parser.cc
|
||||
)
|
||||
|
||||
set(NNACL_DIR ${TOP_DIR}/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl)
|
||||
set(LITE_KERNEL_SRC
|
||||
### nnacl
|
||||
${NNACL_DIR}/common_func.c
|
||||
|
@ -186,6 +185,7 @@ set(LITE_KERNEL_SRC
|
|||
${NNACL_DIR}/base/slice_base.c
|
||||
${NNACL_DIR}/fp32/winograd_utils.c
|
||||
${NNACL_DIR}/fp32/pack_fp32.c
|
||||
${NNACL_DIR}/fp32/arithmetic_fp32.c
|
||||
${NNACL_DIR}/int8/quantize.c
|
||||
${NNACL_DIR}/int8/pack_int8.c
|
||||
${NNACL_DIR}/int8/matmul_int8.c
|
||||
|
|
|
@ -6,6 +6,7 @@ set(WRAPPER_SRC
|
|||
${WRAPPER_DIR}/base/detection_post_process_base_wrapper.c
|
||||
${WRAPPER_DIR}/base/optimize_handler_wrapper.c
|
||||
${WRAPPER_DIR}/fp32/matmul_fp32_wrapper.c
|
||||
${WRAPPER_DIR}/fp32/arithmetic_fp32_wrapper.c
|
||||
${WRAPPER_DIR}/int8/matmul_int8_wrapper.c
|
||||
${WRAPPER_DIR}/int8/add_int8_wrapper.c
|
||||
${WRAPPER_DIR}/int8/concat_int8_wrapper.c
|
||||
|
|
|
@ -18,7 +18,7 @@ include_directories(${3RD_DIR}/flatbuffers/include)
|
|||
#include ms
|
||||
include_directories(${TOP_DIR}/)
|
||||
include_directories(${TOP_DIR}/mindspore/core/)
|
||||
include_directories(${TOP_DIR}/mindspore/ccsrc/backend/kernel_compiler/cpu)
|
||||
include_directories(${NNACL_DIR}/../)
|
||||
include_directories(${LITE_DIR})
|
||||
include_directories(${MICRO_DIR})
|
||||
#include coder
|
||||
|
|
|
@ -15,8 +15,6 @@
|
|||
*/
|
||||
#include "coder/opcoders/nnacl/fp32/arithmetic_fp32_coder.h"
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <type_traits>
|
||||
#include "coder/opcoders/file_collector.h"
|
||||
#include "nnacl/fp32/arithmetic_fp32.h"
|
||||
#include "coder/opcoders/parallel.h"
|
||||
|
@ -24,229 +22,76 @@
|
|||
|
||||
namespace mindspore::lite::micro::nnacl {
|
||||
|
||||
int ArithmeticFP32Coder::Init(CoderContext *const context) {
|
||||
filter_tensor_ = input_tensors_.at(kWeightIndex);
|
||||
MS_CHECK_PTR(filter_tensor_);
|
||||
if (input_tensor_->data_type() == kNumberTypeFloat32 || input_tensor_->data_type() == kNumberTypeFloat16) {
|
||||
data_type_ = kDataTypeFloat;
|
||||
namespace {
|
||||
std::string wrap_void(const std::string &a) { return "(void *)(" + a + ")"; }
|
||||
std::string wrap_uint8(const std::string &a) { return "(uint8_t *)(" + a + ")"; }
|
||||
std::string wrap_offset(const std::string &a, int offset) { return "(" + a + "+" + std::to_string(offset) + ")"; }
|
||||
} // namespace
|
||||
|
||||
void ArithmeticFP32Coder::InitRunFunction(int primitive_type) {
|
||||
ARITHMETIC_FUNC_INFO_FP32 fun_table[] = {
|
||||
{PrimitiveType_MulFusion, schema::ActivationType_RELU, "ElementMulRelu", "ElementMulReluInt", "",
|
||||
"ElementOptMulRelu", "ElementOptMulReluInt"},
|
||||
{PrimitiveType_MulFusion, schema::ActivationType_RELU6, "ElementMulRelu6", "ElementMulRelu6Int", "",
|
||||
"ElementOptMulRelu6", "ElementOptMulRelu6Int"},
|
||||
{PrimitiveType_MulFusion, schema::ActivationType_NO_ACTIVATION, "ElementMul", "ElementMulInt", "", "ElementOptMul",
|
||||
"ElementOptMulInt"},
|
||||
{PrimitiveType_AddFusion, schema::ActivationType_RELU, "ElementAddRelu", "", "", "ElementOptAddRelu", ""},
|
||||
{PrimitiveType_AddFusion, schema::ActivationType_RELU6, "ElementAddRelu6", "", "", "ElementOptAddRelu6", ""},
|
||||
{PrimitiveType_AddFusion, schema::ActivationType_NO_ACTIVATION, "ElementAdd", "ElementAddInt", "", "ElementOptAdd",
|
||||
"ElementOptAddInt"},
|
||||
{PrimitiveType_SubFusion, schema::ActivationType_RELU, "ElementSubRelu", "", "", "ElementOptSubRelu", ""},
|
||||
{PrimitiveType_SubFusion, schema::ActivationType_RELU6, "ElementSubRelu6", "", "", "ElementOptSubRelu6", ""},
|
||||
{PrimitiveType_SubFusion, schema::ActivationType_NO_ACTIVATION, "ElementSub", "ElementSubInt", "", "ElementOptSub",
|
||||
"ElementOptSubInt"},
|
||||
{PrimitiveType_DivFusion, schema::ActivationType_RELU, "ElementDivRelu", "", "", "ElementOptDivRelu", ""},
|
||||
{PrimitiveType_DivFusion, schema::ActivationType_RELU6, "ElementDivRelu6", "", "", "ElementOptDivRelu6", ""},
|
||||
{PrimitiveType_DivFusion, schema::ActivationType_NO_ACTIVATION, "ElementDiv", "", "", "ElementOptDiv",
|
||||
"ElementOptDivInt"},
|
||||
{PrimitiveType_RealDiv, schema::ActivationType_RELU, "ElementDivRelu", "", "", "ElementOptDivRelu", ""},
|
||||
{PrimitiveType_RealDiv, schema::ActivationType_RELU6, "ElementDivRelu6", "", "", "ElementOptDivRelu6", ""},
|
||||
{PrimitiveType_RealDiv, schema::ActivationType_NO_ACTIVATION, "ElementDiv", "", "", "ElementOptDiv",
|
||||
"ElementOptDivInt"},
|
||||
{PrimitiveType_LogicalAnd, schema::ActivationType_NO_ACTIVATION, "ElementLogicalAnd", "ElementLogicalAndInt",
|
||||
"ElementLogicalAndBool", "", ""},
|
||||
{PrimitiveType_LogicalOr, schema::ActivationType_NO_ACTIVATION, "ElementLogicalOr", "", "ElementLogicalOrBool", "",
|
||||
""},
|
||||
{PrimitiveType_Maximum, schema::ActivationType_NO_ACTIVATION, "ElementMaximum", "ElementMaximumInt", "", "", ""},
|
||||
{PrimitiveType_Minimum, schema::ActivationType_NO_ACTIVATION, "ElementMinimum", "ElementMinimumInt", "", "", ""},
|
||||
{PrimitiveType_FloorMod, schema::ActivationType_NO_ACTIVATION, "ElementFloorMod", "ElementFloorModInt", "", "", ""},
|
||||
{PrimitiveType_FloorDiv, schema::ActivationType_NO_ACTIVATION, "ElementFloorDiv", "ElementFloorDivInt", "", "", ""},
|
||||
{PrimitiveType_Mod, schema::ActivationType_NO_ACTIVATION, "ElementMod", "ElementModInt", "", "ElementOptMod",
|
||||
"ElementOptModInt"},
|
||||
{PrimitiveType_SquaredDifference, schema::ActivationType_NO_ACTIVATION, "ElementSquaredDifference", "", "", "",
|
||||
""}};
|
||||
|
||||
size_t length = sizeof(fun_table) / sizeof(ARITHMETIC_FUNC_INFO_FP32);
|
||||
for (size_t i = 0; i < length; i++) {
|
||||
if (fun_table[i].primitive_type_ == primitive_type &&
|
||||
fun_table[i].activation_type_ == arithmetic_parameter_->activation_type_) {
|
||||
arithmetic_run_ = fun_table[i].func_;
|
||||
arithmetic_run_int_ = fun_table[i].int_func_;
|
||||
arithmetic_run_bool_ = fun_table[i].bool_func_;
|
||||
arithmetic_opt_run_ = fun_table[i].opt_func_;
|
||||
arithmetic_opt_run_int_ = fun_table[i].opt_int_func_;
|
||||
return;
|
||||
}
|
||||
}
|
||||
TypeId input_type_id = input_tensor_->data_type();
|
||||
data_type_len_ = lite::DataTypeSize(input_tensor_->data_type());
|
||||
if (input_type_id == kNumberTypeFloat32 || input_type_id == kNumberTypeFloat) {
|
||||
arithmetic_func_type_ = kArithmeticFuncFloat;
|
||||
} else if (input_type_id == kNumberTypeBool) {
|
||||
arithmetic_func_type_ = kArithmeticFuncBool;
|
||||
} else if (input_type_id == kNumberTypeInt || input_type_id == kNumberTypeInt32) {
|
||||
arithmetic_func_type_ = kArithmeticFuncInt;
|
||||
} else {
|
||||
data_type_ = kDataTypeInt;
|
||||
arithmetic_func_type_ = kArithmeticFuncUnknow;
|
||||
}
|
||||
arithmetic_parameter_->in_elements_num0_ = input_tensor_->ElementsNum();
|
||||
arithmetic_parameter_->in_elements_num1_ = filter_tensor_->ElementsNum();
|
||||
arithmetic_parameter_->out_elements_num_ = output_tensor_->ElementsNum();
|
||||
for (size_t i = 0; i < input_tensor_->shape().size(); i++) {
|
||||
if (arithmetic_parameter_->in_shape0_[i] == -1) {
|
||||
MS_CHECK_RET_CODE(
|
||||
memcpy_s(arithmetic_parameter_->in_shape0_, DEFAULT_ARITHMETIC_NDIMS * sizeof(int),
|
||||
static_cast<void *>(input_tensor_->shape().data()), input_tensor_->shape().size() * sizeof(int)),
|
||||
"memcpy_s in shape0 failed!");
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < filter_tensor_->shape().size(); i++) {
|
||||
if (arithmetic_parameter_->in_shape1_[i] == -1) {
|
||||
MS_CHECK_RET_CODE(
|
||||
memcpy_s(arithmetic_parameter_->in_shape1_, DEFAULT_ARITHMETIC_NDIMS * sizeof(int),
|
||||
static_cast<void *>(filter_tensor_->shape().data()), filter_tensor_->shape().size() * sizeof(int)),
|
||||
"memcpy_s in shape1 failed!");
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < output_tensor_->shape().size(); i++) {
|
||||
if (arithmetic_parameter_->out_shape_[i] == -1) {
|
||||
MS_CHECK_RET_CODE(
|
||||
memcpy_s(arithmetic_parameter_->out_shape_, DEFAULT_ARITHMETIC_NDIMS * sizeof(int),
|
||||
static_cast<void *>(output_tensor_->shape().data()), output_tensor_->shape().size() * sizeof(int)),
|
||||
"memcpy_s in out shape failed!");
|
||||
}
|
||||
}
|
||||
|
||||
if (arithmetic_parameter_->in_elements_num0_ == 1 || arithmetic_parameter_->in_elements_num1_ == 1) {
|
||||
switch (arithmetic_parameter_->op_parameter_.type_) {
|
||||
case PrimitiveType_MulFusion:
|
||||
switch (arithmetic_parameter_->activation_type_) {
|
||||
case schema::ActivationType_RELU:
|
||||
arithmetic_parameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = "ElementOptMulRelu";
|
||||
arithmetic_opt_run_int_ = "ElementOptMulReluInt";
|
||||
break;
|
||||
case schema::ActivationType_RELU6:
|
||||
arithmetic_parameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = "ElementOptMulRelu6";
|
||||
arithmetic_opt_run_int_ = "ElementOptMulRelu6Int";
|
||||
break;
|
||||
default:
|
||||
arithmetic_parameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = "ElementOptMul";
|
||||
arithmetic_opt_run_int_ = "ElementOptMulInt";
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case PrimitiveType_AddFusion:
|
||||
switch (arithmetic_parameter_->activation_type_) {
|
||||
case schema::ActivationType_RELU:
|
||||
arithmetic_parameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = "ElementOptAddRelu";
|
||||
arithmetic_opt_run_int_ = "ElementOptAddReluInt";
|
||||
break;
|
||||
case schema::ActivationType_RELU6:
|
||||
arithmetic_parameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = "ElementOptAddRelu6";
|
||||
arithmetic_opt_run_int_ = "ElementOptAddRelu6Int";
|
||||
break;
|
||||
default:
|
||||
arithmetic_parameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = "ElementOptAdd";
|
||||
arithmetic_opt_run_int_ = "ElementOptAddInt";
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case PrimitiveType_SubFusion:
|
||||
switch (arithmetic_parameter_->activation_type_) {
|
||||
case schema::ActivationType_RELU:
|
||||
arithmetic_parameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = "ElementOptSubRelu";
|
||||
break;
|
||||
case schema::ActivationType_RELU6:
|
||||
arithmetic_parameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = "ElementOptSubRelu6";
|
||||
break;
|
||||
default:
|
||||
arithmetic_parameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = "ElementOptSub";
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ArithmeticFP32Coder::BroadcastRun(const std::string &input0, const std::string &input1, const std::string &output,
|
||||
int dim, int out_count, int out_thread_stride, NNaclFp32Serializer *const code) {
|
||||
if (dim > break_pos_) {
|
||||
if (data_type_ == kDataTypeInt) {
|
||||
*code << "\t\t" << arithmetic_run_int_ << "(((" << input0 << ") + " << out_thread_stride << "), ((" << input1
|
||||
<< ") + " << out_thread_stride << "), ((" << output << ") + " << out_thread_stride << "), " << out_count
|
||||
<< ");\n";
|
||||
|
||||
} else {
|
||||
*code << "\t\t" << arithmetic_run_ << "(((" << input0 << ") + " << out_thread_stride << "), ((" << input1
|
||||
<< ") + " << out_thread_stride << "), ((" << output << ") + " << out_thread_stride << "), " << out_count
|
||||
<< ");\n";
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
for (int i = 0; i < arithmetic_parameter_->out_shape_[dim]; ++i) {
|
||||
int pos0_ = arithmetic_parameter_->in_shape0_[dim] == 1 ? 0 : i;
|
||||
int pos1_ = arithmetic_parameter_->in_shape1_[dim] == 1 ? 0 : i;
|
||||
int error_code = BroadcastRun(input0 + "+" + std::to_string(pos0_ * arithmetic_parameter_->in_strides0_[dim]),
|
||||
input1 + "+" + std::to_string(pos1_ * arithmetic_parameter_->in_strides1_[dim]),
|
||||
output + "+" + std::to_string(i * arithmetic_parameter_->out_strides_[dim]), dim + 1,
|
||||
out_count, out_thread_stride, code);
|
||||
if (error_code != RET_OK) {
|
||||
return error_code;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ArithmeticFP32Coder::Prepare(CoderContext *const context) {
|
||||
if (parameter_ == nullptr) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
arithmetic_parameter_ = reinterpret_cast<ArithmeticParameter *>(parameter_);
|
||||
std::map<int, std::function<void()>> type_setters = {
|
||||
{PrimitiveType_MulFusion,
|
||||
[this]() {
|
||||
switch (arithmetic_parameter_->activation_type_) {
|
||||
case schema::ActivationType_RELU:
|
||||
arithmetic_run_ = "ElementMulRelu";
|
||||
arithmetic_run_int_ = "ElementMulReluInt";
|
||||
break;
|
||||
case schema::ActivationType_RELU6:
|
||||
arithmetic_run_ = "ElementMulRelu6";
|
||||
arithmetic_run_int_ = "ElementMulRelu6Int";
|
||||
break;
|
||||
default:
|
||||
arithmetic_run_ = "ElementMul";
|
||||
arithmetic_run_int_ = "ElementMulInt";
|
||||
break;
|
||||
}
|
||||
}},
|
||||
{PrimitiveType_AddFusion,
|
||||
[this]() {
|
||||
switch (arithmetic_parameter_->activation_type_) {
|
||||
case schema::ActivationType_RELU:
|
||||
arithmetic_run_ = "ElementAddRelu";
|
||||
arithmetic_run_int_ = "ElementAddReluInt";
|
||||
break;
|
||||
case schema::ActivationType_RELU6:
|
||||
arithmetic_run_ = "ElementAddRelu6";
|
||||
arithmetic_run_int_ = "ElementAddRelu6Int";
|
||||
break;
|
||||
default:
|
||||
arithmetic_run_ = "ElementAdd";
|
||||
arithmetic_run_int_ = "ElementAddInt";
|
||||
break;
|
||||
}
|
||||
}},
|
||||
{PrimitiveType_SubFusion,
|
||||
[this]() {
|
||||
switch (arithmetic_parameter_->activation_type_) {
|
||||
case schema::ActivationType_RELU:
|
||||
arithmetic_run_ = "ElementSubRelu";
|
||||
break;
|
||||
case schema::ActivationType_RELU6:
|
||||
arithmetic_run_ = "ElementSubRelu6";
|
||||
break;
|
||||
default:
|
||||
arithmetic_run_ = "ElementSub";
|
||||
break;
|
||||
}
|
||||
}},
|
||||
{PrimitiveType_DivFusion,
|
||||
[this]() {
|
||||
switch (arithmetic_parameter_->activation_type_) {
|
||||
case schema::ActivationType_RELU:
|
||||
arithmetic_run_ = "ElementDivRelu";
|
||||
break;
|
||||
case schema::ActivationType_RELU6:
|
||||
arithmetic_run_ = "ElementDivRelu6";
|
||||
break;
|
||||
default:
|
||||
arithmetic_run_ = "ElementDiv";
|
||||
break;
|
||||
}
|
||||
}},
|
||||
{PrimitiveType_LogicalAnd, [this]() { arithmetic_run_ = "ElementLogicalAnd"; }},
|
||||
{PrimitiveType_LogicalOr, [this]() { arithmetic_run_ = "ElementLogicalOr"; }},
|
||||
{PrimitiveType_Maximum, [this]() { arithmetic_run_ = "ElementMaximum"; }},
|
||||
{PrimitiveType_Minimum, [this]() { arithmetic_run_ = "ElementMinimum"; }},
|
||||
{PrimitiveType_FloorDiv, [this]() { arithmetic_run_ = "ElementFloorDiv"; }},
|
||||
{PrimitiveType_FloorMod, [this]() { arithmetic_run_ = "ElementFloorMod"; }},
|
||||
{PrimitiveType_Equal, [this]() { arithmetic_run_ = "ElementEqual"; }},
|
||||
{PrimitiveType_NotEqual, [this]() { arithmetic_run_ = "ElementNotEqual"; }},
|
||||
{PrimitiveType_Less, [this]() { arithmetic_run_ = "ElementLess"; }},
|
||||
{PrimitiveType_LessEqual, [this]() { arithmetic_run_ = "ElementLessEqual"; }},
|
||||
{PrimitiveType_Greater, [this]() { arithmetic_run_ = "ElementGreater"; }},
|
||||
{PrimitiveType_GreaterEqual, [this]() { arithmetic_run_ = "ElementGreaterEqual"; }},
|
||||
{PrimitiveType_SquaredDifference, [this]() { arithmetic_run_ = "ElementSquaredDifference"; }},
|
||||
};
|
||||
auto iter = type_setters.find(parameter_->type_);
|
||||
if (iter != type_setters.end()) {
|
||||
iter->second();
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Error Operator type " << parameter_;
|
||||
arithmetic_run_ = "NULL";
|
||||
return RET_ERROR;
|
||||
}
|
||||
MS_CHECK_RET_CODE(Init(context), "do arothmetic code failed!");
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void ArithmeticFP32Coder::ComputeInOutStrides() {
|
||||
int ArithmeticFP32Coder::ReSize(CoderContext *const context) {
|
||||
CalcMultiplesAndStrides(arithmetic_parameter_);
|
||||
if (arithmetic_parameter_->broadcasting_) {
|
||||
outside_ = 1;
|
||||
for (auto i = arithmetic_parameter_->ndim_ - 1; i >= 0; --i) {
|
||||
|
@ -256,20 +101,185 @@ void ArithmeticFP32Coder::ComputeInOutStrides() {
|
|||
}
|
||||
outside_ *= arithmetic_parameter_->out_shape_[i];
|
||||
}
|
||||
ComputeStrides(arithmetic_parameter_->in_shape0_, arithmetic_parameter_->in_strides0_,
|
||||
arithmetic_parameter_->ndim_);
|
||||
ComputeStrides(arithmetic_parameter_->in_shape1_, arithmetic_parameter_->in_strides1_,
|
||||
arithmetic_parameter_->ndim_);
|
||||
ComputeStrides(arithmetic_parameter_->out_shape_, arithmetic_parameter_->out_strides_,
|
||||
arithmetic_parameter_->ndim_);
|
||||
}
|
||||
int ret = RET_OK;
|
||||
if (!IsScalarClac() && !IsBatchScalarCalc() && !IsBiasCalc()) {
|
||||
ret = ConstTensorBroadCast(context);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
int ArithmeticFP32Coder::CheckDataType() {
|
||||
auto in0_dataType = input_tensor_->data_type();
|
||||
auto in1_dataType = filter_tensor_->data_type();
|
||||
if (in0_dataType != in1_dataType) {
|
||||
MS_LOG(ERROR) << "The dataTypes of input tensor0 and input tensor1 should be the same.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void ArithmeticFP32Coder::ChooseArithmeticFunc(bool is_opt) {
|
||||
if (input_tensor_->data_type() == kNumberTypeFloat32) {
|
||||
if (is_opt) {
|
||||
arithmetic_func_str_ = wrap_void(arithmetic_opt_run_);
|
||||
} else {
|
||||
arithmetic_func_str_ = wrap_void(arithmetic_run_);
|
||||
}
|
||||
} else if (input_tensor_->data_type() == kNumberTypeBool) {
|
||||
arithmetic_func_str_ = wrap_void(arithmetic_run_bool_);
|
||||
} else {
|
||||
if (is_opt) {
|
||||
arithmetic_func_str_ = wrap_void(arithmetic_opt_run_int_);
|
||||
} else {
|
||||
arithmetic_func_str_ = wrap_void(arithmetic_run_int_);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool ArithmeticFP32Coder::IsScalarClac() {
|
||||
return (arithmetic_parameter_->in_elements_num0_ == 1 || arithmetic_parameter_->in_elements_num1_ == 1) &&
|
||||
(!arithmetic_opt_run_.empty());
|
||||
}
|
||||
|
||||
bool ArithmeticFP32Coder::IsBatchScalarCalc() {
|
||||
if (arithmetic_opt_run_.empty()) {
|
||||
return false;
|
||||
}
|
||||
size_t break_axis = 0;
|
||||
for (size_t i = 0; i < arithmetic_parameter_->ndim_; i++) {
|
||||
if (arithmetic_parameter_->in_shape0_[i] != arithmetic_parameter_->in_shape1_[i]) {
|
||||
break_axis = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (break_axis < arithmetic_parameter_->ndim_) {
|
||||
for (size_t i = break_axis; i < arithmetic_parameter_->ndim_; i++) {
|
||||
if (arithmetic_parameter_->in_shape1_[i] != 1) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
break_pos_ = break_axis;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ArithmeticFP32Coder::IsBiasCalc() {
|
||||
int last_shape0 = arithmetic_parameter_->in_shape0_[arithmetic_parameter_->ndim_ - 1];
|
||||
int last_shape1 = arithmetic_parameter_->in_shape1_[arithmetic_parameter_->ndim_ - 1];
|
||||
if (arithmetic_parameter_->in_elements_num0_ > arithmetic_parameter_->in_elements_num1_) {
|
||||
return arithmetic_parameter_->in_elements_num1_ == last_shape1 && last_shape0 == last_shape1;
|
||||
} else if (arithmetic_parameter_->in_elements_num0_ < arithmetic_parameter_->in_elements_num1_) {
|
||||
return arithmetic_parameter_->in_elements_num0_ == last_shape0 && last_shape0 == last_shape1;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void ArithmeticFP32Coder::FreeConstTileBuff() {
|
||||
if (input0_broadcast_ && input0_ptr_ != nullptr) {
|
||||
input0_ptr_ = nullptr;
|
||||
input0_broadcast_ = false;
|
||||
}
|
||||
if (input1_broadcast_ && input1_ptr_ != nullptr) {
|
||||
input1_ptr_ = nullptr;
|
||||
input0_broadcast_ = false;
|
||||
}
|
||||
}
|
||||
|
||||
int ArithmeticFP32Coder::ConstTensorBroadCast(CoderContext *const context) {
|
||||
// if const node need broadcast and all need-broadcast-node are const, broadcast in resize
|
||||
if (!arithmetic_parameter_->broadcasting_) {
|
||||
return RET_OK;
|
||||
}
|
||||
if (output_tensor_->Size() < 0) {
|
||||
return RET_OK;
|
||||
}
|
||||
// need broadcast both input
|
||||
if (arithmetic_parameter_->in_elements_num0_ != arithmetic_parameter_->out_elements_num_ &&
|
||||
arithmetic_parameter_->in_elements_num1_ != arithmetic_parameter_->out_elements_num_) {
|
||||
return RET_OK;
|
||||
}
|
||||
FreeConstTileBuff();
|
||||
NNaclFp32Serializer init_code;
|
||||
Collect(context,
|
||||
{
|
||||
"wrapper/fp32/arithmetic_fp32_wrapper.h",
|
||||
},
|
||||
{
|
||||
"arithmetic_fp32_wrapper.c",
|
||||
});
|
||||
if (input_tensor_->IsConst() &&
|
||||
arithmetic_parameter_->in_elements_num0_ != arithmetic_parameter_->out_elements_num_) {
|
||||
input0_ptr_ = reinterpret_cast<float *>(
|
||||
allocator_->Malloc(kNumberTypeFloat32, arithmetic_parameter_->out_elements_num_ * data_type_len_, kWorkspace));
|
||||
MS_CHECK_PTR(input0_ptr_);
|
||||
init_code.CodeArray("in_shape", arithmetic_parameter_->in_shape0_, arithmetic_parameter_->ndim_, true);
|
||||
init_code.CodeArray("in_stride", arithmetic_parameter_->in_strides0_, arithmetic_parameter_->ndim_, true);
|
||||
init_code.CodeArray("out_stride", arithmetic_parameter_->out_strides_, arithmetic_parameter_->ndim_, true);
|
||||
init_code.CodeArray("multiple", arithmetic_parameter_->multiples0_, arithmetic_parameter_->ndim_, true);
|
||||
init_code.CodeFunction("TileConstTensor", input_tensor_, input0_ptr_, arithmetic_parameter_->ndim_, "in_shape",
|
||||
"in_stride", "out_stride", "multiple");
|
||||
input0_broadcast_ = true;
|
||||
arithmetic_parameter_->in_elements_num0_ = arithmetic_parameter_->out_elements_num_;
|
||||
arithmetic_parameter_->broadcasting_ = false;
|
||||
}
|
||||
if (filter_tensor_->IsConst() &&
|
||||
arithmetic_parameter_->in_elements_num1_ != arithmetic_parameter_->out_elements_num_) {
|
||||
input1_ptr_ = reinterpret_cast<float *>(
|
||||
allocator_->Malloc(kNumberTypeFloat32, arithmetic_parameter_->out_elements_num_ * data_type_len_, kWorkspace));
|
||||
MS_CHECK_PTR(input1_ptr_);
|
||||
init_code.CodeArray("in_shape", arithmetic_parameter_->in_shape1_, arithmetic_parameter_->ndim_, true);
|
||||
init_code.CodeArray("in_stride", arithmetic_parameter_->in_strides1_, arithmetic_parameter_->ndim_, true);
|
||||
init_code.CodeArray("out_stride", arithmetic_parameter_->out_strides_, arithmetic_parameter_->ndim_, true);
|
||||
init_code.CodeArray("multiple", arithmetic_parameter_->multiples1_, arithmetic_parameter_->ndim_, true);
|
||||
init_code.CodeFunction("TileConstTensor", filter_tensor_, input1_ptr_, arithmetic_parameter_->ndim_, "in_shape",
|
||||
"in_stride", "out_stride", "multiple");
|
||||
input1_broadcast_ = true;
|
||||
arithmetic_parameter_->in_elements_num1_ = arithmetic_parameter_->out_elements_num_;
|
||||
arithmetic_parameter_->broadcasting_ = false;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ArithmeticFP32Coder::Prepare(CoderContext *const context) {
|
||||
filter_tensor_ = input_tensors_.at(kWeightIndex);
|
||||
MS_CHECK_PTR(filter_tensor_);
|
||||
MS_CHECK_RET_CODE(CheckDataType(), "ArithmeticFP32Coder check datatype fail");
|
||||
MS_CHECK_PTR(parameter_);
|
||||
arithmetic_parameter_ = reinterpret_cast<ArithmeticParameter *>(parameter_);
|
||||
auto primitive_type = arithmetic_parameter_->op_parameter_.type_;
|
||||
if (primitive_type == schema::PrimitiveType_Eltwise) {
|
||||
switch (arithmetic_parameter_->eltwise_mode_) {
|
||||
case schema::EltwiseMode_PROD:
|
||||
primitive_type = schema::PrimitiveType_MulFusion;
|
||||
break;
|
||||
case schema::EltwiseMode_SUM:
|
||||
primitive_type = schema::PrimitiveType_AddFusion;
|
||||
break;
|
||||
case schema::EltwiseMode_MAXIMUM:
|
||||
primitive_type = schema::PrimitiveType_Maximum;
|
||||
break;
|
||||
default:
|
||||
MS_LOG(ERROR) << "Eltwise mode not support, mode:" << arithmetic_parameter_->eltwise_mode_;
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
InitRunFunction(primitive_type);
|
||||
MS_CHECK_RET_CODE(ReSize(context), "do arithmetic ReSize fail!");
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void ArithmeticFP32Coder::CollectFilesForFunc(CoderContext *const context) {
|
||||
/**
|
||||
* for nnacl's operator combine all arithmetic to nnalc/arithmetic.c
|
||||
* this solution is not suitable for micro, for the size of package.
|
||||
* */
|
||||
// collect wrapper files
|
||||
Collect(context,
|
||||
{
|
||||
"wrapper/fp32/arithmetic_fp32_wrapper.h",
|
||||
},
|
||||
{
|
||||
"arithmetic_fp32_wrapper.c",
|
||||
});
|
||||
// for nnacl's operator combine all arithmetic to nnalc/arithmetic.c
|
||||
// this solution is not suitable for micro, for the size of package.
|
||||
if (arithmetic_opt_run_ == "ElementOptSub" || arithmetic_run_ == "ElementSub") {
|
||||
Collect(context,
|
||||
{
|
||||
|
@ -324,57 +334,148 @@ void ArithmeticFP32Coder::CollectFilesForFunc(CoderContext *const context) {
|
|||
}
|
||||
}
|
||||
|
||||
int ArithmeticFP32Coder::ExecuteCode(const std::string &input0, const std::string &input1, const std::string &output,
|
||||
int size, bool is_opt, CoderContext *const context,
|
||||
NNaclFp32Serializer *const code) {
|
||||
if (arithmetic_func_str_.empty()) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
code->CodeStruct("arithmetic_parameter", *arithmetic_parameter_);
|
||||
code->CodeFunction("ArithmeticExecute", input0, input1, output, size, is_opt, arithmetic_func_type_,
|
||||
arithmetic_func_str_, "&arithmetic_parameter");
|
||||
context->AppendCode(code->str());
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ArithmeticFP32Coder::BatchScalarCalc(int task_id, CoderContext *const context, NNaclFp32Serializer *const code) {
|
||||
if (break_pos_ < 1) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
int batch = arithmetic_parameter_->out_elements_num_ / arithmetic_parameter_->out_strides_[break_pos_ - 1];
|
||||
int batch_per_thread = UP_DIV(batch, thread_num_);
|
||||
|
||||
int start_batch = batch_per_thread * task_id;
|
||||
int end_batch = MSMIN(start_batch + batch_per_thread, batch);
|
||||
int batch_size = end_batch - start_batch;
|
||||
|
||||
int stride0 = arithmetic_parameter_->in_strides0_[break_pos_ - 1] * data_type_len_;
|
||||
int stride1 = arithmetic_parameter_->in_strides1_[break_pos_ - 1] * data_type_len_;
|
||||
int out_stride = arithmetic_parameter_->out_strides_[break_pos_ - 1] * data_type_len_;
|
||||
|
||||
int offset0 = stride0 * start_batch;
|
||||
int offset1 = stride1 * start_batch;
|
||||
int out_offset = out_stride * start_batch;
|
||||
|
||||
arithmetic_wrapper_info_ = {offset0, stride0, offset1, stride1, out_offset, out_stride, arithmetic_func_type_};
|
||||
|
||||
code->CodeStruct("arithmetic_wrapper_info", arithmetic_wrapper_info_);
|
||||
code->CodeStruct("arithmetic_parameter", *arithmetic_parameter_);
|
||||
code->CodeFunction("BatchScalarCalc", wrap_uint8(input0_ptr_str_), wrap_uint8(input1_ptr_str_),
|
||||
wrap_uint8(output_ptr_str_), batch_size, arithmetic_parameter_->out_strides_[break_pos_ - 1], true,
|
||||
arithmetic_func_str_, "&arithmetic_wrapper_info", "&arithmetic_parameter");
|
||||
context->AppendCode(code->str());
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ArithmeticFP32Coder::BiasCalc(int task_id, CoderContext *const context, NNaclFp32Serializer *const code) {
|
||||
int last_shape = arithmetic_parameter_->out_shape_[arithmetic_parameter_->ndim_ - 1];
|
||||
int batch = arithmetic_parameter_->out_elements_num_ / last_shape;
|
||||
int batch_per_thread = UP_DIV(batch, thread_num_);
|
||||
|
||||
int start_batch = batch_per_thread * task_id;
|
||||
int end_batch = MSMIN(start_batch + batch_per_thread, batch);
|
||||
int batch_size = end_batch - start_batch;
|
||||
|
||||
int stride = last_shape * data_type_len_;
|
||||
int offset = stride * start_batch;
|
||||
code->CodeStruct("arithmetic_parameter", *arithmetic_parameter_);
|
||||
if (arithmetic_parameter_->in_elements_num0_ > arithmetic_parameter_->in_elements_num1_) {
|
||||
arithmetic_wrapper_info_ = {offset, stride, 0, 0, offset, stride, arithmetic_func_type_};
|
||||
code->CodeStruct("arithmetic_wrapper_info", arithmetic_wrapper_info_);
|
||||
code->CodeFunction("BatchScalarCalc", wrap_uint8(input0_ptr_str_), wrap_uint8(input1_ptr_str_),
|
||||
wrap_uint8(output_ptr_str_), batch_size, last_shape, false, arithmetic_func_str_,
|
||||
"&arithmetic_wrapper_info", "&arithmetic_parameter");
|
||||
} else {
|
||||
arithmetic_wrapper_info_ = {0, 0, offset, stride, offset, stride, arithmetic_func_type_};
|
||||
code->CodeStruct("arithmetic_wrapper_info", arithmetic_wrapper_info_);
|
||||
code->CodeFunction("BatchScalarCalc", wrap_uint8(input0_ptr_str_), wrap_uint8(input1_ptr_str_),
|
||||
wrap_uint8(output_ptr_str_), batch_size, last_shape, false, arithmetic_func_str_,
|
||||
"&arithmetic_wrapper_info", "&arithmetic_parameter");
|
||||
}
|
||||
context->AppendCode(code->str());
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ArithmeticFP32Coder::BroadcastRun(const std::string &input0, const std::string &input1, const std::string &output,
|
||||
int dim, int out_count, int out_thread_stride, CoderContext *const context,
|
||||
NNaclFp32Serializer *const code) {
|
||||
code->CodeStruct("arithmetic_parameter", *arithmetic_parameter_);
|
||||
code->CodeFunction("BroadcastRun", wrap_uint8(input0_ptr_str_), wrap_uint8(input1_ptr_str_),
|
||||
wrap_uint8(output_ptr_str_), dim, out_count, out_thread_stride, break_pos_, data_type_len_,
|
||||
arithmetic_func_type_, arithmetic_func_str_, "&arithmetic_parameter");
|
||||
context->AppendCode(code->str());
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
// Entry point for arithmetic fp32 code generation: decides between the
// broadcast path, the optimized (one-input-scalar) kernels, and the plain
// elementwise kernels, then emits the corresponding generated-code call.
//
// NOTE(review): this body appears to be a fused old+new diff. Everything
// after the first `return RET_OK;` below is unreachable legacy code, and the
// brace balance suggests one closing `}` was lost before the DEBUG log —
// reconcile against the upstream commit before relying on this text.
int ArithmeticFP32Coder::DoCode(CoderContext *const context) {
  ComputeInOutStrides();
  int element_num = output_tensor_->ElementsNum();
  MS_CHECK_TRUE(thread_num_ > 0, "thread_num_ is less than zero");
  // Split the flat element range into thread_num_ slices; micro codegen only
  // materialises the slice belonging to kDefaultTaskId.
  int stride = UP_DIV(element_num, thread_num_);
  int count = MSMIN(stride, element_num - stride * kDefaultTaskId);
  MS_CHECK_TRUE(!arithmetic_run_.empty(), "arithmetic_run function is nullptr!");
  if (count <= 0) {
    // Nothing assigned to this task slice.
    return RET_OK;
  }
  // Byte offset of this slice inside the flat buffers.
  int offset = stride * kDefaultTaskId * data_type_len_;
  input0_ptr_str_ = allocator_->GetRuntimeAddr(input0_ptr_);
  input1_ptr_str_ = allocator_->GetRuntimeAddr(input1_ptr_);
  // If an input was not pre-broadcast (tiled) at Prepare time, address the
  // original tensor directly. NOTE(review): when neither input is broadcast
  // the second block simply overwrites the first — presumably intended, but
  // confirm the `true` (immutable) flag placement.
  if (!input0_broadcast_) {
    input0_ptr_str_ = allocator_->GetRuntimeAddr(input_tensor_, true);
    input1_ptr_str_ = allocator_->GetRuntimeAddr(filter_tensor_);
  }
  if (!input1_broadcast_) {
    input0_ptr_str_ = allocator_->GetRuntimeAddr(input_tensor_);
    input1_ptr_str_ = allocator_->GetRuntimeAddr(filter_tensor_, true);
  }
  output_ptr_str_ = allocator_->GetRuntimeAddr(output_tensor_);
  NNaclFp32Serializer code;
  CollectFilesForFunc(context);
  if (arithmetic_parameter_->broadcasting_) {
    // Runtime-broadcast path: re-slice the work over `outside_` rather than
    // over flat elements.
    stride = UP_DIV(outside_, thread_num_);
    out_count_ = MSMIN(stride, outside_ - stride * kDefaultTaskId);
    out_thread_stride_ = stride * kDefaultTaskId;
    std::string input0_str = allocator_->GetRuntimeAddr(input_tensor_);
    std::string input1_str = allocator_->GetRuntimeAddr(filter_tensor_);
    std::string output_str = allocator_->GetRuntimeAddr(output_tensor_);
    // NOTE(review): BroadcastRun's declaration takes a CoderContext before
    // the serializer, but `context` is not passed here — verify this call
    // matches the intended signature.
    MS_CHECK_RET_CODE(BroadcastRun(input0_str, input1_str, output_str, 0, out_count_, out_thread_stride_, &code),
                      "do broad cast code failed!");
  } else if (!arithmetic_opt_run_.empty()) {
    // Optimized kernels are available for this op.
    code.CodeStruct("arithmetic_parameter", *arithmetic_parameter_);
    if (IsScalarClac()) {
      ChooseArithmeticFunc(true);
      if (arithmetic_parameter_->in_elements_num0_ == 1) {
        // input0 is the scalar operand.
        if (data_type_ == kDataTypeFloat) {
          code.CodeFunction(arithmetic_opt_run_, input_tensor_, filter_tensor_, output_tensor_, count,
                            "&arithmetic_parameter");
        } else {
          code.CodeFunction(arithmetic_opt_run_int_, input_tensor_, filter_tensor_, output_tensor_, count,
                            "&arithmetic_parameter");
        }
        // NOTE(review): this emits ExecuteCode on top of the CodeFunction
        // above — double emission? Looks like diff fusion; confirm upstream.
        return ExecuteCode(wrap_uint8(input0_ptr_str_), wrap_offset(wrap_uint8(input1_ptr_str_), offset),
                           wrap_offset(wrap_uint8(output_ptr_str_), offset), count, true, context, &code);
      } else if (arithmetic_parameter_->in_elements_num1_ == 1) {
        // input1 is the scalar operand.
        if (data_type_ == kDataTypeFloat) {
          code.CodeFunction(arithmetic_opt_run_, input_tensor_, filter_tensor_, output_tensor_, count,
                            "&arithmetic_parameter");
        } else {
          code.CodeFunction(arithmetic_opt_run_int_, input_tensor_, filter_tensor_, output_tensor_, count,
                            "&arithmetic_parameter");
        }
      } else {
        MS_LOG(ERROR) << "arithmetic opt code run: at least one of inputs is scalar";
        return RET_ERROR;
      }
    } else {
      // Non-scalar case with optimized kernels: fall back to the plain
      // elementwise kernel emission.
      if (data_type_ == kDataTypeFloat) {
        code.CodeFunction(arithmetic_run_, input_tensor_, filter_tensor_, output_tensor_, count);
      } else {
        code.CodeFunction(arithmetic_run_int_, input_tensor_, filter_tensor_, output_tensor_, count);
        // NOTE(review): int branch additionally returns via ExecuteCode while
        // the float branch falls through — asymmetry likely from diff fusion.
        return ExecuteCode(wrap_offset(wrap_uint8(input0_ptr_str_), offset), wrap_void(input1_ptr_str_),
                           wrap_offset(wrap_uint8(output_ptr_str_), offset), count, true, context, &code);
      }
    }
  MS_LOG(DEBUG) << "ArithmeticFP32Code has been called";
  context->AppendCode(code.str());

  return RET_OK;
  // NOTE(review): everything below is unreachable (dead) code — it reads as
  // the pre-commit implementation left behind by the diff view.
  // run opt function, every batch one of input is scalar
  if (IsBatchScalarCalc()) {
    ChooseArithmeticFunc(true);
    return BatchScalarCalc(kDefaultTaskId, context, &code);
  }

  // each batch is eltwise calculation
  if (IsBiasCalc()) {
    ChooseArithmeticFunc(false);
    return BiasCalc(kDefaultTaskId, context, &code);
  }

  // need broadcast in runtime
  if (arithmetic_parameter_->broadcasting_) {
    ChooseArithmeticFunc(false);
    stride = UP_DIV(outside_, thread_num_);
    int out_count = MSMIN(stride, outside_ - stride * kDefaultTaskId);
    if (out_count <= 0) {
      return RET_OK;
    }
    return BroadcastRun(input0_ptr_str_, input1_ptr_str_, output_ptr_str_, 0, out_count, stride * kDefaultTaskId,
                        context, &code);
  }
  // all elements eltwise calculation
  ChooseArithmeticFunc(false);
  return ExecuteCode(wrap_offset(wrap_uint8(input0_ptr_str_), offset), wrap_offset(wrap_uint8(input1_ptr_str_), offset),
                     wrap_offset(wrap_uint8(output_ptr_str_), offset), count, false, context, &code);
}
|
||||
|
||||
REG_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_AddFusion, CPUOpCoderCreator<ArithmeticFP32Coder>)
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
#include "coder/opcoders/op_coder.h"
|
||||
#include "nnacl/fp32/arithmetic_fp32.h"
|
||||
#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
|
||||
#define DEFAULT_ARITHMETIC_NDIMS 10
|
||||
#include "wrapper/fp32/arithmetic_fp32_wrapper.h"
|
||||
namespace mindspore::lite::micro::nnacl {
|
||||
|
||||
using mindspore::schema::PrimitiveType_AddFusion;
|
||||
|
@ -65,7 +65,19 @@ using mindspore::schema::PrimitiveType_Eltwise;
|
|||
|
||||
using mindspore::schema::PrimitiveType_Minimum;
|
||||
|
||||
using mindspore::schema::PrimitiveType_Mod;
|
||||
|
||||
class ArithmeticFP32Coder final : public OperatorCoder {
|
||||
typedef struct {
|
||||
int primitive_type_;
|
||||
int activation_type_;
|
||||
std::string func_;
|
||||
std::string int_func_;
|
||||
std::string bool_func_;
|
||||
std::string opt_func_;
|
||||
std::string opt_int_func_;
|
||||
} ARITHMETIC_FUNC_INFO_FP32;
|
||||
|
||||
public:
|
||||
ArithmeticFP32Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
|
||||
const Model::Node *node, size_t node_index, Target target)
|
||||
|
@ -78,15 +90,39 @@ class ArithmeticFP32Coder final : public OperatorCoder {
|
|||
int DoCode(CoderContext *const context) override;
|
||||
|
||||
private:
|
||||
int Init(CoderContext *const context);
|
||||
int ReSize(CoderContext *const context);
|
||||
|
||||
int ExecuteCode(const std::string &input0, const std::string &input1, const std::string &output, int size,
|
||||
bool is_opt, CoderContext *const context, NNaclFp32Serializer *const code);
|
||||
|
||||
void InitRunFunction(int primitive_type);
|
||||
|
||||
int CheckDataType();
|
||||
|
||||
void ChooseArithmeticFunc(bool is_opt);
|
||||
|
||||
bool IsScalarClac();
|
||||
|
||||
bool IsBatchScalarCalc();
|
||||
|
||||
bool IsBiasCalc();
|
||||
|
||||
void FreeConstTileBuff();
|
||||
|
||||
int ConstTensorBroadCast(CoderContext *const context);
|
||||
|
||||
void ComputeInOutStrides();
|
||||
|
||||
int BroadcastRun(const std::string &input0, const std::string &input1, const std::string &output, int dim,
|
||||
int out_count, int out_thread_stride, NNaclFp32Serializer *const code);
|
||||
int out_count, int out_thread_stride, CoderContext *const context, NNaclFp32Serializer *const code);
|
||||
|
||||
int BatchScalarCalc(int task_id, CoderContext *const context, NNaclFp32Serializer *const code);
|
||||
|
||||
int BiasCalc(int task_id, CoderContext *const context, NNaclFp32Serializer *const code);
|
||||
|
||||
void CollectFilesForFunc(CoderContext *const context);
|
||||
|
||||
private:
|
||||
int break_pos_{0};
|
||||
|
||||
int outside_{0};
|
||||
|
@ -95,10 +131,32 @@ class ArithmeticFP32Coder final : public OperatorCoder {
|
|||
|
||||
int out_count_{0};
|
||||
|
||||
int data_type_len_{0};
|
||||
|
||||
bool input0_broadcast_{false};
|
||||
|
||||
bool input1_broadcast_{false};
|
||||
|
||||
float *input0_ptr_{nullptr};
|
||||
|
||||
float *input1_ptr_{nullptr};
|
||||
|
||||
float *output_ptr_{nullptr};
|
||||
|
||||
ArithmeticParameter *arithmetic_parameter_{nullptr};
|
||||
|
||||
Tensor *filter_tensor_{nullptr};
|
||||
|
||||
ArithmeticFuncType arithmetic_func_type_{kArithmeticFuncUnknow};
|
||||
|
||||
ArithmeticWrapperInfo arithmetic_wrapper_info_{};
|
||||
|
||||
std::string input0_ptr_str_;
|
||||
|
||||
std::string input1_ptr_str_;
|
||||
|
||||
std::string output_ptr_str_;
|
||||
|
||||
std::string arithmetic_run_;
|
||||
|
||||
std::string arithmetic_run_int_;
|
||||
|
@ -107,7 +165,9 @@ class ArithmeticFP32Coder final : public OperatorCoder {
|
|||
|
||||
std::string arithmetic_opt_run_int_;
|
||||
|
||||
LiteDataType data_type_{kDataTypeFloat};
|
||||
std::string arithmetic_run_bool_;
|
||||
|
||||
std::string arithmetic_func_str_;
|
||||
};
|
||||
} // namespace mindspore::lite::micro::nnacl
|
||||
#endif // MINDSPORE_LITE_MICRO_CODER_OPCODERS_FP32_ARITHMETIC_FP32_CODER_H_
|
||||
|
|
|
@ -28,7 +28,6 @@
|
|||
using mindspore::schema::PrimitiveType_MatMul;
|
||||
|
||||
namespace mindspore::lite::micro::nnacl {
|
||||
|
||||
int MatMulFP32BaseCoder::ReSize() {
|
||||
ResizeParameter();
|
||||
thread_count_ = MSMIN(thread_num_, UP_DIV(params_->col_align_, col_tile_));
|
||||
|
@ -45,8 +44,12 @@ int MatMulFP32BaseCoder::ReSize() {
|
|||
|
||||
int MatMulFP32BaseCoder::InitBiasData() {
|
||||
if (input_tensors_.size() == 3) {
|
||||
int max_bias_data = UP_ROUND(bias_tensor_->ElementsNum(), C16NUM);
|
||||
int max_bias_data = params_->col_align_;
|
||||
bias_pack_ptr_size_ = static_cast<size_t>(max_bias_data * sizeof(float));
|
||||
if (bias_tensor_->ElementsNum() == 1) {
|
||||
is_bias_broadcast_ = true;
|
||||
}
|
||||
ori_bias_pack_ptr_size_ = bias_tensor_->ElementsNum() * sizeof(float);
|
||||
bias_ptr_ = reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeFloat32, kOnlineSize, kOnlinePackWeight));
|
||||
MS_CHECK_PTR(bias_ptr_);
|
||||
}
|
||||
|
@ -123,8 +126,7 @@ int MatMulFP32BaseCoder::Init() {
|
|||
|
||||
int MatMulFP32BaseCoder::Prepare(CoderContext *const context) { return RET_OK; }
|
||||
|
||||
int MatMulFP32BaseCoder::DoCode(CoderContext *const context) {
|
||||
// generate code .h .c
|
||||
int MatMulFP32BaseCoder::CollectFilesForTarget(CoderContext *const context) {
|
||||
Collect(context,
|
||||
{
|
||||
"nnacl/fp32/matmul_fp32.h",
|
||||
|
@ -158,6 +160,11 @@ int MatMulFP32BaseCoder::DoCode(CoderContext *const context) {
|
|||
"dequant_int8_to_fp32_wrapper.c",
|
||||
});
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int MatMulFP32BaseCoder::DoCode(CoderContext *const context) {
|
||||
CollectFilesForTarget(context);
|
||||
NNaclFp32Serializer code;
|
||||
NNaclFp32Serializer init_code;
|
||||
code.CodeStruct("mat_mul_parameter", *params_);
|
||||
|
@ -165,7 +172,17 @@ int MatMulFP32BaseCoder::DoCode(CoderContext *const context) {
|
|||
// do bias packing to init
|
||||
if (bias_ptr_) {
|
||||
init_code.CodeMallocExpression(bias_ptr_, bias_pack_ptr_size_);
|
||||
init_code.CodeFunction("memcpy", bias_ptr_, bias_tensor_, bias_pack_ptr_size_);
|
||||
init_code.CodeFunction("memset", bias_ptr_, 0, bias_pack_ptr_size_);
|
||||
int max_bias_data = params_->col_align_;
|
||||
if (is_bias_broadcast_) {
|
||||
float broad_cast_data = (reinterpret_cast<float *>(bias_tensor_->data_c()))[0];
|
||||
std::string bias_ptr_str = "((float *)(" + allocator_->GetRuntimeAddr(bias_ptr_) + "))";
|
||||
init_code << "\tfor (int i = 0; i < " << max_bias_data << "; ++i) {\n";
|
||||
init_code << "\t\t" << bias_ptr_str << "[i] = " << broad_cast_data << ";\n";
|
||||
init_code << "\t}\n";
|
||||
} else {
|
||||
init_code.CodeFunction("memcpy", bias_ptr_, bias_tensor_, ori_bias_pack_ptr_size_);
|
||||
}
|
||||
}
|
||||
|
||||
// Get Tensor Pointer
|
||||
|
@ -179,6 +196,7 @@ int MatMulFP32BaseCoder::DoCode(CoderContext *const context) {
|
|||
if (!params_->a_const_) {
|
||||
code.CodeFunction("InitMatrixA", input_tensor_, a_pack_ptr_, "&mat_mul_parameter", vec_matmul_);
|
||||
init_code.CodeMallocExpression(b_pack_ptr_, b_pack_ptr_size_);
|
||||
init_code.CodeFunction("memset", b_pack_ptr_, 0, b_pack_ptr_size_);
|
||||
std::string b_src_str = b_str;
|
||||
if (de_quant_flag_) {
|
||||
// reuse to b_pack_str
|
||||
|
@ -191,6 +209,7 @@ int MatMulFP32BaseCoder::DoCode(CoderContext *const context) {
|
|||
}
|
||||
if (!params_->b_const_) {
|
||||
init_code.CodeMallocExpression(a_pack_str, a_pack_ptr_size_);
|
||||
init_code.CodeFunction("memset", a_pack_ptr_, 0, a_pack_ptr_size_);
|
||||
std::string a_src_str = a_str;
|
||||
if (de_quant_flag_) {
|
||||
// reuse to a_pack_str
|
||||
|
@ -228,7 +247,6 @@ int MatMulFP32BaseCoder::DoCode(CoderContext *const context) {
|
|||
params_->deep_, params_->row_, cur_oc, params_->col_, "OutType_Nhwc");
|
||||
}
|
||||
code << "\t\t}\n";
|
||||
|
||||
context->AppendCode(code.str());
|
||||
context->AppendInitCode(init_code.str());
|
||||
return RET_OK;
|
||||
|
|
|
@ -43,6 +43,7 @@ class MatMulFP32BaseCoder : public OperatorCoder {
|
|||
int InitBufferB();
|
||||
int InitMatrixA(const float *src_ptr);
|
||||
int InitMatrixB(const float *src_ptr);
|
||||
int CollectFilesForTarget(CoderContext *const context);
|
||||
|
||||
protected:
|
||||
virtual int Init();
|
||||
|
@ -64,8 +65,10 @@ class MatMulFP32BaseCoder : public OperatorCoder {
|
|||
int thread_stride_{0};
|
||||
int thread_count_{0};
|
||||
size_t bias_pack_ptr_size_{0};
|
||||
size_t ori_bias_pack_ptr_size_{0};
|
||||
size_t a_pack_ptr_size_{0};
|
||||
size_t b_pack_ptr_size_{0};
|
||||
bool is_bias_broadcast_{false};
|
||||
};
|
||||
} // namespace mindspore::lite::micro::nnacl
|
||||
#endif // MINDSPORE_LITE_MICRO_CODER_OPCODERS_NNACL_FP32_MATMUL_FP32_BASE_CODER_H_
|
||||
|
|
|
@ -74,13 +74,13 @@ void NNaclFp32Serializer::CodeStruct(const std::string &name, const ConvParamete
|
|||
}
|
||||
|
||||
void NNaclFp32Serializer::CodeStruct(const std::string &name, const MatMulParameter &mat_mul_parameter) {
|
||||
CodeBaseStruct("MatMulParameter", name, mat_mul_parameter.op_parameter_, mat_mul_parameter.has_bias_,
|
||||
mat_mul_parameter.row_, mat_mul_parameter.col_, mat_mul_parameter.row_4_, mat_mul_parameter.row_6_,
|
||||
mat_mul_parameter.row_12_, mat_mul_parameter.row_16_, mat_mul_parameter.row_align_,
|
||||
mat_mul_parameter.col_4_, mat_mul_parameter.col_8_, mat_mul_parameter.col_align_,
|
||||
mat_mul_parameter.deep_, mat_mul_parameter.deep_4_, mat_mul_parameter.deep_16_,
|
||||
mat_mul_parameter.batch, mat_mul_parameter.a_transpose_, mat_mul_parameter.b_transpose_,
|
||||
mat_mul_parameter.a_const_, mat_mul_parameter.b_const_, mat_mul_parameter.act_type_);
|
||||
CodeBaseStruct(
|
||||
"MatMulParameter", name, mat_mul_parameter.op_parameter_, mat_mul_parameter.has_bias_, mat_mul_parameter.row_,
|
||||
mat_mul_parameter.col_, mat_mul_parameter.row_4_, mat_mul_parameter.row_6_, mat_mul_parameter.row_12_,
|
||||
mat_mul_parameter.row_16_, mat_mul_parameter.row_align_, mat_mul_parameter.col_4_, mat_mul_parameter.col_8_,
|
||||
mat_mul_parameter.col_align_, mat_mul_parameter.deep_, mat_mul_parameter.deep_4_, mat_mul_parameter.deep_16_,
|
||||
mat_mul_parameter.batch, mat_mul_parameter.a_transpose_, mat_mul_parameter.b_transpose_, mat_mul_parameter.a_const_,
|
||||
mat_mul_parameter.b_const_, mat_mul_parameter.act_type_, mat_mul_parameter.use_axis_, mat_mul_parameter.axis_);
|
||||
}
|
||||
|
||||
void NNaclFp32Serializer::CodeStruct(const std::string &name, const ScaleParameter &scale_parameter) {
|
||||
|
@ -143,4 +143,11 @@ void NNaclFp32Serializer::CodeStruct(const std::string &name, const StridedSlice
|
|||
strided_slice_parameter.begins_mask_, strided_slice_parameter.ellipsisMask_,
|
||||
strided_slice_parameter.newAxisMask_, strided_slice_parameter.shrinkAxisMask_);
|
||||
}
|
||||
|
||||
// Serializes an ArithmeticWrapperInfo (per-batch offsets/strides plus the
// kernel-signature tag) into the generated source as a named C initializer.
// Field order must match the ArithmeticWrapperInfo struct declared in
// wrapper/fp32/arithmetic_fp32_wrapper.h.
void NNaclFp32Serializer::CodeStruct(const std::string &name, const ArithmeticWrapperInfo &arithmetic_wrapper_info) {
  CodeBaseStruct("ArithmeticWrapperInfo", name, arithmetic_wrapper_info.offset0_, arithmetic_wrapper_info.stride0_,
                 arithmetic_wrapper_info.offset1_, arithmetic_wrapper_info.stride1_,
                 arithmetic_wrapper_info.out_offset_, arithmetic_wrapper_info.out_stride_,
                 arithmetic_wrapper_info.arithmetic_func_type_);
}
|
||||
} // namespace mindspore::lite::micro::nnacl
|
||||
|
|
|
@ -34,6 +34,7 @@
|
|||
#include "wrapper/fp32/dequant_int8_to_fp32_wrapper.h"
|
||||
#include "nnacl/fp32/exp_fp32.h"
|
||||
#include "nnacl/fp32/strided_slice_fp32.h"
|
||||
#include "wrapper/fp32/arithmetic_fp32_wrapper.h"
|
||||
namespace mindspore::lite::micro::nnacl {
|
||||
|
||||
class NNaclFp32Serializer : public Serializer {
|
||||
|
@ -55,6 +56,7 @@ class NNaclFp32Serializer : public Serializer {
|
|||
void CodeStruct(const std::string &name, const SpliceParameter &splice_parameter);
|
||||
void CodeStruct(const std::string &name, const ExpParameter &exp_parameter);
|
||||
void CodeStruct(const std::string &name, const StridedSliceParameter &strided_slice_parameter);
|
||||
void CodeStruct(const std::string &name, const ArithmeticWrapperInfo &arithmetic_wrapper_info);
|
||||
};
|
||||
|
||||
} // namespace mindspore::lite::micro::nnacl
|
||||
|
|
|
@ -0,0 +1,83 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "wrapper/fp32/arithmetic_fp32_wrapper.h"
|
||||
// Thin wrapper used by generated code: tiles a constant input tensor up to
// the broadcast output shape by delegating to the nnacl primitive, starting
// the expansion at dimension 0.
void TileConstTensor(const float *in_data, float *out_data, size_t ndim, const int *in_shape, const int *in_strides,
                     const int *out_strides, const int *multiple) {
  TileOneDimensionFp32(in_data, out_data, 0, ndim, in_shape, in_strides, out_strides, multiple);
}
|
||||
|
||||
void ArithmeticExecute(const void *input0, const void *input1, void *output, int size, bool is_opt,
|
||||
ArithmeticFuncType arithmetic_func_type, const void *arithmetic_func,
|
||||
const ArithmeticParameter *param) {
|
||||
if (arithmetic_func_type == kArithmeticFuncFloat) {
|
||||
if (is_opt) {
|
||||
ArithmeticOptRun arithmetic_opt_run = (ArithmeticOptRun)(arithmetic_func);
|
||||
arithmetic_opt_run((const float *)(input0), (const float *)(input1), (float *)(output), size, param);
|
||||
} else {
|
||||
ArithmeticRun arithmetic_run = (ArithmeticRun)(arithmetic_func);
|
||||
arithmetic_run((const float *)(input0), (const float *)(input1), (float *)(output), size);
|
||||
}
|
||||
} else if (arithmetic_func_type == kArithmeticFuncBool) {
|
||||
ArithmeticBoolRun arithmetic_run_bool = (ArithmeticBoolRun)(arithmetic_func);
|
||||
arithmetic_run_bool((const bool *)(input0), (const bool *)(input1), (bool *)(output), size);
|
||||
} else if (arithmetic_func_type == kArithmeticFuncInt) {
|
||||
if (is_opt) {
|
||||
ArithmeticOptIntRun arithmetic_opt_run_int = (ArithmeticOptIntRun)(arithmetic_func);
|
||||
arithmetic_opt_run_int((const int *)(input0), (const int *)(input1), (int *)(output), size, param);
|
||||
} else {
|
||||
ArithmeticIntRun arithmetic_run_int = (ArithmeticIntRun)(arithmetic_func);
|
||||
arithmetic_run_int((const int *)(input0), (const int *)(input1), (int *)(output), size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void BatchScalarCalc(const void *input0, const void *input1, void *output, int batch_size, int size, bool is_opt,
|
||||
const void *arithmetic_func, const ArithmeticWrapperInfo *wrapper_info,
|
||||
const ArithmeticParameter *param) {
|
||||
int offset0 = wrapper_info->offset0_;
|
||||
int offset1 = wrapper_info->offset1_;
|
||||
int out_offset = wrapper_info->out_offset_;
|
||||
int stride0 = wrapper_info->stride0_;
|
||||
int stride1 = wrapper_info->stride1_;
|
||||
int out_stride = wrapper_info->out_stride_;
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
ArithmeticExecute((const uint8_t *)(input0) + offset0, (const uint8_t *)(input1) + offset1,
|
||||
(uint8_t *)(output) + out_offset, size, is_opt, wrapper_info->arithmetic_func_type_,
|
||||
arithmetic_func, param);
|
||||
offset0 += stride0;
|
||||
offset1 += stride1;
|
||||
out_offset += out_stride;
|
||||
}
|
||||
}
|
||||
|
||||
// Recursively walks the broadcast dimensions described by `param`. For
// dimensions at or before `break_pos` the function fans out one recursive
// call per output index (re-using index 0 on any size-1 input dimension);
// once `dim` passes `break_pos` the remaining data is contiguous and a single
// kernel invocation handles this thread's slice (`out_count` elements at
// `out_thread_stride`).
//
// Fix: the leaf branch previously fell through into the recursion loop after
// calling the kernel, so it kept recursing with dim+1 and read
// param->out_shape_/strides past the valid broadcast depth (garbage indices,
// runaway recursion). A `return` after the leaf computation stops descent.
void BroadcastRun(const void *input0, const void *input1, void *output, int dim, int out_count, int out_thread_stride,
                  int break_pos, int data_type_len, ArithmeticFuncType arithmetic_func_type,
                  const void *arithmetic_func, const ArithmeticParameter *param) {
  if (dim > break_pos) {
    // Leaf: data is contiguous from here on; compute this thread's slice.
    int offset = out_thread_stride * data_type_len;
    ArithmeticExecute((const uint8_t *)(input0) + offset, (const uint8_t *)(input1) + offset,
                      (uint8_t *)(output) + offset, out_count, false, arithmetic_func_type, arithmetic_func, param);
    return;  // fix: do not fall through into the per-dimension recursion
  }
  // Byte strides of one step along `dim` for input0, input1 and the output.
  int offset_size[] = {param->in_strides0_[dim] * data_type_len, param->in_strides1_[dim] * data_type_len,
                       param->out_strides_[dim] * data_type_len};
  for (int i = 0; i < param->out_shape_[dim]; ++i) {
    // A size-1 input dimension is broadcast: always read its index 0.
    int pos0_ = param->in_shape0_[dim] == 1 ? 0 : i;
    int pos1_ = param->in_shape1_[dim] == 1 ? 0 : i;
    BroadcastRun((const uint8_t *)(input0) + pos0_ * offset_size[0], (const uint8_t *)(input1) + pos1_ * offset_size[1],
                 (uint8_t *)(output) + i * offset_size[2], dim + 1, out_count, out_thread_stride, break_pos,
                 data_type_len, arithmetic_func_type, arithmetic_func, param);
  }
}
|
|
@ -0,0 +1,66 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_MICRO_CODER_WRAPPER_FP32_ARITHMETIC_FP32_WRAPPER_H_
#define MINDSPORE_LITE_MICRO_CODER_WRAPPER_FP32_ARITHMETIC_FP32_WRAPPER_H_
#include "nnacl/fp32/arithmetic_fp32.h"
#include <string.h>
#ifdef __cplusplus
extern "C" {
#endif
// Tag telling the wrappers which function-pointer signature the type-erased
// `arithmetic_func` argument must be cast to before invocation.
typedef enum ArithmeticFuncType {
  kArithmeticFuncFloat = 0,
  kArithmeticFuncBool = 1,
  kArithmeticFuncInt = 2,
  kArithmeticFuncUnknow = 3,
} ArithmeticFuncType;

// Per-batch addressing info for BatchScalarCalc: initial byte offsets, the
// per-batch byte strides, and the kernel-signature tag.
typedef struct ArithmeticWrapperInfo {
  int offset0_;      // initial byte offset into input0
  int stride0_;      // per-batch byte stride for input0
  int offset1_;      // initial byte offset into input1
  int stride1_;      // per-batch byte stride for input1
  int out_offset_;   // initial byte offset into the output
  int out_stride_;   // per-batch byte stride for the output
  ArithmeticFuncType arithmetic_func_type_;
} ArithmeticWrapperInfo;

// Signatures of the nnacl arithmetic kernels the wrappers can dispatch to.
// "Opt" variants additionally receive the ArithmeticParameter (used when one
// operand is a scalar).
typedef int (*ArithmeticRun)(const float *input0, const float *input1, float *output, const int element_size);
typedef int (*ArithmeticOptRun)(const float *input0, const float *input1, float *output, const int element_size,
                                const ArithmeticParameter *param);
typedef int (*ArithmeticIntRun)(const int *input0, const int *input1, int *output, const int element_size);
typedef int (*ArithmeticOptIntRun)(const int *input0, const int *input1, int *output, const int element_size,
                                   const ArithmeticParameter *param);
typedef int (*ArithmeticBoolRun)(const bool *input0, const bool *input1, bool *output, const int element_size);

// Invokes `arithmetic_func` on `size` elements after casting it according to
// `arithmetic_func_type` (and `is_opt` for float/int).
void ArithmeticExecute(const void *input0, const void *input1, void *output, int size, bool is_opt,
                       ArithmeticFuncType arithmetic_func_type, const void *arithmetic_func,
                       const ArithmeticParameter *param);

// Tiles a constant input tensor up to the broadcast output shape
// (delegates to TileOneDimensionFp32 starting at dimension 0).
void TileConstTensor(const float *in_data, float *out_data, size_t ndim, const int *in_shape, const int *in_strides,
                     const int *out_strides, const int *multiple);

// Runs the kernel once per batch, advancing each pointer by the offsets and
// strides recorded in `wrapper_info`.
void BatchScalarCalc(const void *input0, const void *input1, void *output, int batch_size, int size, bool is_opt,
                     const void *arithmetic_func, const ArithmeticWrapperInfo *wrapper_info,
                     const ArithmeticParameter *param);

// Recursively expands broadcast dimensions up to `break_pos`, then computes
// `out_count` contiguous elements per leaf for this thread's slice.
void BroadcastRun(const void *input0, const void *input1, void *output, int dim, int out_count, int out_thread_stride,
                  int break_pos, int data_type_len, ArithmeticFuncType arithmetic_func_type,
                  const void *arithmetic_func, const ArithmeticParameter *param);

#ifdef __cplusplus
}
#endif
#endif  // MINDSPORE_LITE_MICRO_CODER_WRAPPER_FP32_ARITHMETIC_FP32_WRAPPER_H_
|
|
@ -11,6 +11,7 @@ include(${TOP_DIR}/cmake/external_libs/gtest.cmake)
|
|||
include(${MICRO_DIR}/cmake/file_list.cmake)
|
||||
include(${MICRO_DIR}/cmake/package_wrapper.cmake)
|
||||
|
||||
include_directories(${NNACL_DIR}/../)
|
||||
include_directories(${TOP_DIR})
|
||||
include_directories(${TOP_DIR}/mindspore/core/)
|
||||
include_directories(${LITE_DIR})
|
||||
|
|
|
@ -90,17 +90,52 @@ int MatmulFp32BaseCPUKernel::InitBufferB() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
// Computes how many float elements the packed bias buffer needs when the bias
// tensor is a scalar that must be broadcast across the output columns.
// Defaults to the bias tensor's own element count rounded up to C16NUM; when
// the weight (input 1) is constant, derives the column count from the weight
// shape instead (respecting b_transpose_).
int MatmulFp32BaseCPUKernel::CalBroadCastBiasDataElements() {
  lite::Tensor *bias_tensor = in_tensors_.at(2);
  int max_bias_data = UP_ROUND(bias_tensor->ElementsNum(), C16NUM);
  if (!params_->b_const_) {
    // Without a constant weight the true column count is unknown here.
    MS_LOG(WARNING) << "matmul do not support broadcast bias data";
  } else {
    lite::Tensor *const_tensor = in_tensors_.at(1);
    size_t shape_size = const_tensor->shape().size();
    // Weight has too few dimensions to carry a column axis; keep the default.
    // NOTE(review): assumes kBiasIndex is the minimum rank needed — confirm
    // the constant's value against the kernel header.
    if (shape_size < kBiasIndex) {
      return max_bias_data;
    }
    if (params_->b_transpose_) {
      // Transposed weight: columns live at shape[rank - kBiasIndex].
      max_bias_data = UP_ROUND(const_tensor->shape()[shape_size - kBiasIndex], C16NUM);
    } else {
      // Plain weight: columns live at shape[rank - kWeightIndex].
      max_bias_data = UP_ROUND(const_tensor->shape()[shape_size - kWeightIndex], C16NUM);
    }
  }
  return max_bias_data;
}
|
||||
|
||||
int MatmulFp32BaseCPUKernel::InitBiasData() {
|
||||
if (in_tensors_.size() == 3) {
|
||||
auto bias_tensor = in_tensors_[2];
|
||||
int max_bias_data = UP_ROUND(bias_tensor->ElementsNum(), C16NUM);
|
||||
bias_ptr_ = reinterpret_cast<float *>(malloc(max_bias_data * sizeof(float)));
|
||||
if (bias_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc bias_ptr_ failed";
|
||||
return RET_ERROR;
|
||||
// whether to broadcast bias data
|
||||
if (bias_tensor->ElementsNum() == 1) {
|
||||
max_bias_data = CalBroadCastBiasDataElements();
|
||||
bias_ptr_ = reinterpret_cast<float *>(malloc(max_bias_data * sizeof(float)));
|
||||
if (bias_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc bias_ptr_ failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
float broadcast_data = (reinterpret_cast<float *>(bias_tensor->data_c()))[0];
|
||||
// broadcast bias data
|
||||
for (int i = 0; i < max_bias_data; ++i) {
|
||||
bias_ptr_[i] = broadcast_data;
|
||||
}
|
||||
} else {
|
||||
bias_ptr_ = reinterpret_cast<float *>(malloc(max_bias_data * sizeof(float)));
|
||||
if (bias_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc bias_ptr_ failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(bias_ptr_, 0, max_bias_data * sizeof(float));
|
||||
memcpy(bias_ptr_, bias_tensor->data_c(), bias_tensor->ElementsNum() * sizeof(float));
|
||||
}
|
||||
memset(bias_ptr_, 0, max_bias_data * sizeof(float));
|
||||
memcpy(bias_ptr_, bias_tensor->data_c(), bias_tensor->ElementsNum() * sizeof(float));
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -56,6 +56,7 @@ class MatmulFp32BaseCPUKernel : public LiteKernel {
|
|||
void ResizeParameter();
|
||||
void FreeResizeBufA();
|
||||
void FreeResizeBufB();
|
||||
int CalBroadCastBiasDataElements();
|
||||
|
||||
protected:
|
||||
MatMulParameter *params_ = nullptr;
|
||||
|
|
|
@ -6,7 +6,6 @@ set(CCSRC_SRC
|
|||
${CCSRC_DIR}/backend/optimizer/common/visit.cc
|
||||
${CCSRC_DIR}/backend/optimizer/common/optimizer.cc
|
||||
)
|
||||
set(NNACL_DIR ${CCSRC_DIR}/backend/kernel_compiler/cpu/nnacl)
|
||||
|
||||
include(${TOP_DIR}/cmake/external_libs/glog.cmake)
|
||||
include_directories(${TOP_DIR}/mindspore/ccsrc/backend/kernel_compiler/cpu)
|
||||
|
|
Loading…
Reference in New Issue