forked from mindspore-Ecosystem/mindspore
!6990 MSLITE adjust conv_param position and add tflite custom parser
Merge pull request !6990 from 徐安越/master
This commit is contained in:
commit f9353bb963
@@ -78,6 +78,7 @@ enum TypeId : int {
   kNumberTypeFloat16,
   kNumberTypeFloat32,
   kNumberTypeFloat64,
+  kNumberTypeComplex64,
   kNumberTypeEnd
 };
 }  // namespace mindspore

@@ -212,6 +212,9 @@ union PrimitiveType {
     CustomExtractFeatures,
     AudioSpectrogram,
     Mfcc,
+    Rfft,
+    FftReal,
+    FftImag,
 }

 enum QuantType: int {

@@ -987,4 +987,14 @@ table Mfcc {
     freqLowerLimit : float;
     filterBankChannelNum : int;
     dctCoeffNum : int;
 }
+
+table Rfft {
+    fftLength : int;
+}
+
+table FftReal {
+}
+
+table FftImag {
+}

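For context: an Rfft over a length-fftLength real signal yields fftLength/2 + 1 complex frequency bins, and FftReal / FftImag extract the real and imaginary parts of a complex tensor. The following is only an illustrative, naive O(n^2) sketch of that relationship, not the converter or runtime code from this commit; NaiveRfft and its names are hypothetical.

#include <cmath>
#include <complex>
#include <cstdio>
#include <vector>

// Naive real DFT, only to illustrate the semantics of the new
// Rfft / FftReal / FftImag primitives.
std::vector<std::complex<double>> NaiveRfft(const std::vector<double> &x, int fft_length) {
  const double kPi = 3.14159265358979323846;
  std::vector<std::complex<double>> out(fft_length / 2 + 1);  // rfft keeps the non-negative frequencies
  for (int k = 0; k < static_cast<int>(out.size()); ++k) {
    std::complex<double> acc(0.0, 0.0);
    for (int n = 0; n < fft_length && n < static_cast<int>(x.size()); ++n) {
      double angle = -2.0 * kPi * k * n / fft_length;
      acc += x[n] * std::complex<double>(std::cos(angle), std::sin(angle));
    }
    out[k] = acc;
  }
  return out;
}

int main() {
  std::vector<double> signal = {1.0, 0.0, -1.0, 0.0};  // one period of a cosine
  auto spectrum = NaiveRfft(signal, 4);                 // fftLength = 4 -> 3 complex bins
  for (const auto &bin : spectrum) {
    // FftReal / FftImag would pull out these two components as separate tensors.
    std::printf("real=%.1f imag=%.1f\n", bin.real(), bin.imag());
  }
  return 0;
}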
@@ -110,7 +110,8 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphInputTensors(const std::vect
   for (const auto &kernel : input_kernels) {
     for (const auto &tensor : kernel->in_tensors()) {
       auto iter = std::find(all_output_tensors.begin(), all_output_tensors.end(), tensor);
-      if (iter == all_output_tensors.end() && tensor->data_c() == nullptr) {
+      if (iter == all_output_tensors.end() &&
+          !(tensor->category() == mindspore::lite::Tensor::CONST && tensor->data_c() != nullptr)) {
         input_tensors.emplace_back(tensor);
       }
     }

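The condition changes what counts as a subgraph input: previously a tensor was skipped whenever it already carried data, whereas now only CONST tensors that carry data (i.e. weights baked into the model) are excluded. A minimal standalone sketch of the new predicate, with Tensor reduced to a hypothetical stub for illustration:

#include <iostream>

// Hypothetical stub mirroring only the fields the predicate touches;
// the real mindspore::lite::Tensor has many more members.
struct Tensor {
  enum Category { CONST, VAR };
  Category category_;
  void *data_;
  Category category() const { return category_; }
  void *data_c() const { return data_; }
};

// A tensor is a subgraph input unless it is produced inside the subgraph
// or it is a CONST tensor that already holds data.
bool IsSubgraphInput(const Tensor &t, bool produced_inside_subgraph) {
  if (produced_inside_subgraph) {
    return false;  // outputs of other kernels in the subgraph are not inputs
  }
  return !(t.category() == Tensor::CONST && t.data_c() != nullptr);
}

int main() {
  int buffer = 0;
  Tensor const_weight{Tensor::CONST, &buffer};
  Tensor activation{Tensor::VAR, &buffer};  // has data but is not const
  std::cout << IsSubgraphInput(const_weight, false) << "\n";  // 0
  std::cout << IsSubgraphInput(activation, false) << "\n";    // 1
  return 0;
}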
@@ -171,16 +171,16 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
   auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter);
   int kernel_h = conv_param->kernel_h_;
   int kernel_w = conv_param->kernel_w_;
-  conv_param->input_h_ = inputs.front()->Height();
-  conv_param->input_w_ = inputs.front()->Width();
-  conv_param->input_channel_ = inputs.front()->Channel();
-  conv_param->output_h_ = outputs.front()->Height();
-  conv_param->output_w_ = outputs.front()->Width();
-  conv_param->output_channel_ = outputs.front()->Channel();
-  conv_param->op_parameter_.thread_num_ = ctx->thread_num_;
   bool use_winograd = false;
   int out_unit;
   if (primitive != nullptr && primitive->GetInferFlag()) {
+    conv_param->input_h_ = inputs.front()->Height();
+    conv_param->input_w_ = inputs.front()->Width();
+    conv_param->input_channel_ = inputs.front()->Channel();
+    conv_param->output_h_ = outputs.front()->Height();
+    conv_param->output_w_ = outputs.front()->Width();
+    conv_param->output_channel_ = outputs.front()->Channel();
+    conv_param->op_parameter_.thread_num_ = ctx->thread_num_;
     CheckIfUseWinograd(&use_winograd, &out_unit, conv_param);
   }

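This is the "adjust conv_param position" part of the commit title: the tensor dimensions are only read once primitive->GetInferFlag() reports that shape inference has run, instead of unconditionally. A hedged, self-contained sketch of the pattern, with ConvParameter, TensorLike, and infer_done reduced to hypothetical stand-ins rather than the real types:

#include <cstdio>

// Hypothetical, simplified stand-ins; names mirror the diff but this is not the actual API.
struct ConvParameter {
  int input_h_ = 0, input_w_ = 0, input_channel_ = 0;
  int output_h_ = 0, output_w_ = 0, output_channel_ = 0;
};

struct TensorLike {
  int h, w, c;
  int Height() const { return h; }
  int Width() const { return w; }
  int Channel() const { return c; }
};

// Copy shape-dependent fields only once shapes are known, mirroring the
// guard on GetInferFlag() in the hunk above.
void FillConvShape(ConvParameter *p, const TensorLike &in, const TensorLike &out, bool infer_done) {
  if (!infer_done) {
    return;  // shapes not inferred yet; leave the parameter untouched
  }
  p->input_h_ = in.Height();
  p->input_w_ = in.Width();
  p->input_channel_ = in.Channel();
  p->output_h_ = out.Height();
  p->output_w_ = out.Width();
  p->output_channel_ = out.Channel();
}

int main() {
  ConvParameter param;
  TensorLike in{224, 224, 3}, out{112, 112, 32};
  FillConvShape(&param, in, out, /*infer_done=*/true);
  std::printf("%d x %d x %d -> %d x %d x %d\n", param.input_h_, param.input_w_,
              param.input_channel_, param.output_h_, param.output_w_, param.output_channel_);
  return 0;
}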
@@ -137,6 +137,49 @@ STATUS TfliteCustomParser::ExtractFeatures(const std::vector<uint8_t> &custom_at
   return RET_OK;
 }

+STATUS TfliteCustomParser::Rfft(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
+                                const std::unique_ptr<tflite::OperatorT> &tflite_op,
+                                const std::unique_ptr<tflite::ModelT> &tflite_model) {
+  std::unique_ptr<schema::RfftT> attr = std::make_unique<schema::RfftT>();
+  if (attr == nullptr) {
+    MS_LOG(ERROR) << "new op failed";
+    return RET_NULL_PTR;
+  }
+  std::vector<int> fft_length;
+  if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, fft_length)) {
+    MS_LOG(ERROR) << "rfft -> fftLength get failed";
+    return RET_ERROR;
+  }
+  attr->fftLength = fft_length[0];
+  op->primitive->value.type = schema::PrimitiveType_Rfft;
+  op->primitive->value.value = attr.release();
+  return RET_OK;
+}
+
+STATUS TfliteCustomParser::FftReal(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
+                                   const std::unique_ptr<tflite::OperatorT> &tflite_op) {
+  std::unique_ptr<schema::FftRealT> attr = std::make_unique<schema::FftRealT>();
+  if (attr == nullptr) {
+    MS_LOG(ERROR) << "new op failed";
+    return RET_NULL_PTR;
+  }
+  op->primitive->value.type = schema::PrimitiveType_FftReal;
+  op->primitive->value.value = attr.release();
+  return RET_OK;
+}
+
+STATUS TfliteCustomParser::FftImag(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
+                                   const std::unique_ptr<tflite::OperatorT> &tflite_op) {
+  std::unique_ptr<schema::FftImagT> attr = std::make_unique<schema::FftImagT>();
+  if (attr == nullptr) {
+    MS_LOG(ERROR) << "new op failed";
+    return RET_NULL_PTR;
+  }
+  op->primitive->value.type = schema::PrimitiveType_FftImag;
+  op->primitive->value.value = attr.release();
+  return RET_OK;
+}
+
 STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
                                  const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
   MS_LOG(DEBUG) << "parse TfliteCustomParser";

@@ -163,6 +206,12 @@ STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
     status = ExtractFeatures(custom_attr, op, tflite_op);
   } else if (custom_type == "AudioSpectrogram") {
     status = AudioSpectrogram(custom_attr, op, tflite_op);
+  } else if (custom_type == "FlexRFFT") {
+    status = Rfft(custom_attr, op, tflite_op, tflite_model);
+  } else if (custom_type == "FlexReal") {
+    status = FftReal(custom_attr, op, tflite_op);
+  } else if (custom_type == "FlexImag") {
+    status = FftImag(custom_attr, op, tflite_op);
   } else {
     MS_LOG(ERROR) << "the custom op hasn't been supported now";
     status = RET_NOT_FIND_OP;

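The dispatch keys on the custom-op name recorded in the TFLite model; the "FlexRFFT", "FlexReal", and "FlexImag" names appear to be Flex-delegate wrappers around TensorFlow's RFFT, Real, and Imag ops. The same routing could be expressed as a lookup table; the sketch below is only a hedged illustration of that design choice, with Status and Handler as hypothetical stand-ins rather than anything in the parser:

#include <functional>
#include <iostream>
#include <map>
#include <string>

// Hypothetical status codes and handler type, standing in for the parser's
// STATUS values and member-function calls.
enum Status { RET_OK = 0, RET_NOT_FIND_OP = -1 };
using Handler = std::function<Status()>;

Status Dispatch(const std::string &custom_type, const std::map<std::string, Handler> &handlers) {
  auto it = handlers.find(custom_type);
  if (it == handlers.end()) {
    std::cerr << "the custom op hasn't been supported now: " << custom_type << "\n";
    return RET_NOT_FIND_OP;  // mirrors the final else branch in the diff
  }
  return it->second();
}

int main() {
  std::map<std::string, Handler> handlers = {
      {"FlexRFFT", [] { std::cout << "parse Rfft\n"; return RET_OK; }},
      {"FlexReal", [] { std::cout << "parse FftReal\n"; return RET_OK; }},
      {"FlexImag", [] { std::cout << "parse FftImag\n"; return RET_OK; }},
  };
  Dispatch("FlexRFFT", handlers);
  Dispatch("FlexAbs", handlers);  // unsupported -> RET_NOT_FIND_OP
  return 0;
}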
@@ -49,6 +49,15 @@ class TfliteCustomParser : public TfliteNodeParser {

   STATUS ExtractFeatures(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
                          const std::unique_ptr<tflite::OperatorT> &tflite_op);
+
+  STATUS Rfft(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
+              const std::unique_ptr<tflite::OperatorT> &tflite_op, const std::unique_ptr<tflite::ModelT> &tflite_model);
+
+  STATUS FftReal(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
+                 const std::unique_ptr<tflite::OperatorT> &tflite_op);
+
+  STATUS FftImag(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
+                 const std::unique_ptr<tflite::OperatorT> &tflite_op);
 };
 }  // namespace lite
 }  // namespace mindspore

@@ -133,12 +133,12 @@ std::map<tflite::ActivationFunctionType, schema::ActivationType> tfMsActivationF
 };

 std::map<int, TypeId> type_map = {
-  {tflite::TensorType_FLOAT64, TypeId::kNumberTypeFloat64}, {tflite::TensorType_FLOAT32, TypeId::kNumberTypeFloat32},
-  {tflite::TensorType_FLOAT16, TypeId::kNumberTypeFloat16}, {tflite::TensorType_INT32, TypeId::kNumberTypeInt32},
-  {tflite::TensorType_INT16, TypeId::kNumberTypeInt16}, {tflite::TensorType_INT8, TypeId::kNumberTypeInt8},
-  {tflite::TensorType_INT64, TypeId::kNumberTypeInt64}, {tflite::TensorType_UINT8, TypeId::kNumberTypeUInt8},
-  {tflite::TensorType_BOOL, TypeId::kNumberTypeBool}, {tflite::TensorType_STRING, TypeId::kObjectTypeString},
-};
+  {tflite::TensorType_FLOAT64, TypeId::kNumberTypeFloat64}, {tflite::TensorType_FLOAT32, TypeId::kNumberTypeFloat32},
+  {tflite::TensorType_FLOAT16, TypeId::kNumberTypeFloat16}, {tflite::TensorType_INT32, TypeId::kNumberTypeInt32},
+  {tflite::TensorType_INT16, TypeId::kNumberTypeInt16}, {tflite::TensorType_INT8, TypeId::kNumberTypeInt8},
+  {tflite::TensorType_INT64, TypeId::kNumberTypeInt64}, {tflite::TensorType_UINT8, TypeId::kNumberTypeUInt8},
+  {tflite::TensorType_BOOL, TypeId::kNumberTypeBool}, {tflite::TensorType_STRING, TypeId::kObjectTypeString},
+  {tflite::TensorType_COMPLEX64, TypeId::kNumberTypeComplex64}};

 schema::ActivationType GetActivationFunctionType(tflite::ActivationFunctionType tfliteAFType) {
   return tfMsActivationFunctionMap.at(tfliteAFType);
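With the map extended, complex64 TFLite tensors now resolve to the new kNumberTypeComplex64 id added in the first hunk. A minimal hedged sketch of a defensive lookup with a fallback for unmapped types; the local enums are stand-ins for the real tflite::TensorType and TypeId values, not the converter's code:

#include <iostream>
#include <map>

// Stand-in enums; the real code maps tflite::TensorType to mindspore TypeId.
enum class TfType { FLOAT32, COMPLEX64, STRING };
enum class MsType { kNumberTypeFloat32, kNumberTypeComplex64, kObjectTypeString, kTypeUnknown };

MsType ToMsType(TfType t) {
  static const std::map<TfType, MsType> kTypeMap = {
      {TfType::FLOAT32, MsType::kNumberTypeFloat32},
      {TfType::COMPLEX64, MsType::kNumberTypeComplex64},  // mapping added by this commit
      {TfType::STRING, MsType::kObjectTypeString},
  };
  auto it = kTypeMap.find(t);
  return it == kTypeMap.end() ? MsType::kTypeUnknown : it->second;  // avoid at() throwing on unknown types
}

int main() {
  std::cout << (ToMsType(TfType::COMPLEX64) == MsType::kNumberTypeComplex64) << "\n";  // 1
  return 0;
}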