forked from mindspore-Ecosystem/mindspore
!7716 [lite] support complex64 and convertion about audio model and fix bug
Merge pull request !7716 from 徐安越/master
This commit is contained in:
commit
97cde36843
|
@ -47,6 +47,12 @@ Float::Float(const int nbits) : Number(FloatBitsToTypeId(nbits), nbits, false) {
|
|||
}
|
||||
}
|
||||
|
||||
Complex::Complex(const int nbits) : Number(TypeId::kNumberTypeComplex64, nbits, false) {
|
||||
if (nbits != 64) {
|
||||
MS_LOG(EXCEPTION) << "Wrong number of bits.";
|
||||
}
|
||||
}
|
||||
|
||||
const TypePtr kBool = std::make_shared<Bool>();
|
||||
const TypePtr kInt8 = std::make_shared<Int>(8);
|
||||
const TypePtr kInt16 = std::make_shared<Int>(16);
|
||||
|
@ -63,4 +69,5 @@ const TypePtr kInt = std::make_shared<Int>();
|
|||
const TypePtr kUInt = std::make_shared<UInt>();
|
||||
const TypePtr kFloat = std::make_shared<Float>();
|
||||
const TypePtr kNumber = std::make_shared<Number>();
|
||||
const TypePtr kComplex64 = std::make_shared<Complex>(64);
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -150,6 +150,28 @@ class Float : public Number {
|
|||
}
|
||||
};
|
||||
|
||||
// Complex
|
||||
class Complex : public Number {
|
||||
public:
|
||||
Complex() : Number(kNumberTypeComplex64, 0) {}
|
||||
explicit Complex(const int nbits);
|
||||
~Complex() override {}
|
||||
MS_DECLARE_PARENT(Complex, Number)
|
||||
|
||||
TypeId generic_type_id() const override { return kNumberTypeComplex64; }
|
||||
TypePtr DeepCopy() const override {
|
||||
if (nbits() == 0) {
|
||||
return std::make_shared<Complex>();
|
||||
}
|
||||
return std::make_shared<Complex>(nbits());
|
||||
}
|
||||
std::string ToString() const override { return GetTypeName("Complex64"); }
|
||||
std::string ToReprString() const override { return nbits() == 0 ? "complex64_" : GetTypeName("complex64"); }
|
||||
std::string DumpText() const override {
|
||||
return nbits() == 0 ? std::string("Complex64") : std::string("C") + std::to_string(nbits());
|
||||
}
|
||||
};
|
||||
|
||||
extern const TypePtr kBool;
|
||||
extern const TypePtr kInt8;
|
||||
extern const TypePtr kInt16;
|
||||
|
@ -166,6 +188,7 @@ extern const TypePtr kInt;
|
|||
extern const TypePtr kUInt;
|
||||
extern const TypePtr kFloat;
|
||||
extern const TypePtr kNumber;
|
||||
extern const TypePtr kComplex64;
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_IR_DTYPE_NUMBER_H_
|
||||
|
|
|
@ -69,6 +69,8 @@ TypePtr TypeIdToType(TypeId id) {
|
|||
return kFloat32;
|
||||
case kNumberTypeFloat64:
|
||||
return kFloat64;
|
||||
case kNumberTypeComplex64:
|
||||
return kComplex64;
|
||||
case kNumberTypeInt8:
|
||||
return kInt8;
|
||||
case kNumberTypeInt16:
|
||||
|
|
|
@ -0,0 +1,107 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/audio_spectrogram.h"
|
||||
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
#include "src/ops/ops_register.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
int AudioSpectrogram::GetWindowSize() const { return this->primitive_->value.AsAudioSpectrogram()->windowSize; }
|
||||
int AudioSpectrogram::GetStride() const { return this->primitive_->value.AsAudioSpectrogram()->stride; }
|
||||
bool AudioSpectrogram::GetMagSquare() const { return this->primitive_->value.AsAudioSpectrogram()->magSquare; }
|
||||
|
||||
#else
|
||||
int AudioSpectrogram::GetWindowSize() const { return this->primitive_->value_as_AudioSpectrogram()->windowSize(); }
|
||||
int AudioSpectrogram::GetStride() const { return this->primitive_->value_as_AudioSpectrogram()->stride(); }
|
||||
bool AudioSpectrogram::GetMagSquare() const { return this->primitive_->value_as_AudioSpectrogram()->magSquare(); }
|
||||
int AudioSpectrogram::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_AudioSpectrogram();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_Add return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateAudioSpectrogram(*fbb, attr->windowSize(), attr->stride(), attr->magSquare());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_AudioSpectrogram, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *AudioSpectrogramCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<AudioSpectrogram>(primitive);
|
||||
}
|
||||
Registry AudioSpectrogramRegistry(schema::PrimitiveType_AudioSpectrogram, AudioSpectrogramCreator);
|
||||
#endif
|
||||
int AudioSpectrogram::Log2Ceil(uint32_t length) {
|
||||
if (length == 0) {
|
||||
return -1;
|
||||
}
|
||||
int floor = 0;
|
||||
for (int i = 4; i >= 0; --i) {
|
||||
int shift = (1 << i);
|
||||
uint32_t tmp = length >> shift;
|
||||
if (tmp != 0) {
|
||||
length = tmp;
|
||||
floor += shift;
|
||||
}
|
||||
}
|
||||
return length == (length & ~(length - 1)) ? floor : floor + 1;
|
||||
}
|
||||
uint32_t AudioSpectrogram::GetFftLength(uint32_t length) {
|
||||
int shift = Log2Ceil(length);
|
||||
return 1 << shift;
|
||||
}
|
||||
int AudioSpectrogram::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
auto input = inputs_.front();
|
||||
MS_ASSERT(input != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto input_shape = input->shape();
|
||||
if (input_shape.size() != 2) {
|
||||
MS_LOG(ERROR) << "input shape is error, which need to be 2 dimensions";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (GetWindowSize() < 2) {
|
||||
MS_LOG(ERROR) << "window size is too short, now is " << GetWindowSize();
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (GetStride() < 1) {
|
||||
MS_LOG(ERROR) << "stride must be positive, now is " << GetStride();
|
||||
return RET_ERROR;
|
||||
}
|
||||
std::vector<int> output_shape(3);
|
||||
output_shape[0] = input_shape[1];
|
||||
// output height
|
||||
int sample_sub_window = input_shape[0] - GetWindowSize();
|
||||
output_shape[1] = sample_sub_window < 0 ? 0 : 1 + sample_sub_window / GetStride();
|
||||
// compute fft length
|
||||
int fft_length = GetFftLength(GetWindowSize());
|
||||
output_shape[2] = fft_length / 2 + 1;
|
||||
outputs_.front()->set_shape(output_shape);
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,51 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LITE_MINDSPORE_LITE_C_OPS_AUDIO_SPECTROGRAM_H_
|
||||
#define LITE_MINDSPORE_LITE_C_OPS_AUDIO_SPECTROGRAM_H_
|
||||
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <cmath>
|
||||
#include "src/ops/primitive_c.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class AudioSpectrogram : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(AudioSpectrogram, PrimitiveC);
|
||||
AudioSpectrogram() = default;
|
||||
explicit AudioSpectrogram(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetWindowSize(int window_size) { this->primitive_->value.AsAudioSpectrogram()->windowSize = window_size; }
|
||||
void SetStride(int stride) { this->primitive_->value.AsAudioSpectrogram()->stride = stride; }
|
||||
void SetMagSquare(bool mag_square) { this->primitive_->value.AsAudioSpectrogram()->magSquare = mag_square; }
|
||||
#else
|
||||
AudioSpectrogram() = default;
|
||||
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int GetWindowSize() const;
|
||||
int GetStride() const;
|
||||
bool GetMagSquare() const;
|
||||
int Log2Ceil(uint32_t length);
|
||||
uint32_t GetFftLength(uint32_t length);
|
||||
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_AUDIO_SPECTROGRAM_H_
|
|
@ -0,0 +1,54 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/fft_imag.h"
|
||||
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
#include "src/ops/ops_register.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
int FftImag::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto val_offset = schema::CreateEqual(*fbb);
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_FftImag, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *FftImagCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<FftImag>(primitive); }
|
||||
Registry FftImagRegistry(schema::PrimitiveType_FftImag, FftImagCreator);
|
||||
#endif
|
||||
int FftImag::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
auto input = inputs_.front();
|
||||
MS_ASSERT(input != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
output->set_data_type(TypeId::kNumberTypeFloat32);
|
||||
output->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto input_shape = input->shape();
|
||||
input_shape.pop_back();
|
||||
outputs_.front()->set_shape(input_shape);
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,43 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LITE_MINDSPORE_LITE_C_OPS_FFT_IMAG_H_
|
||||
#define LITE_MINDSPORE_LITE_C_OPS_FFT_IMAG_H_
|
||||
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <cmath>
|
||||
#include "src/ops/primitive_c.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class FftImag : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(FftImag, PrimitiveC);
|
||||
FftImag() = default;
|
||||
explicit FftImag(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#else
|
||||
FftImag() = default;
|
||||
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_FFT_IMAG_H_
|
|
@ -0,0 +1,54 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/fft_real.h"
|
||||
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
#include "src/ops/ops_register.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
int FftReal::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto val_offset = schema::CreateEqual(*fbb);
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_FftReal, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *FftRealCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<FftReal>(primitive); }
|
||||
Registry FftRealRegistry(schema::PrimitiveType_FftReal, FftRealCreator);
|
||||
#endif
|
||||
int FftReal::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
auto input = inputs_.front();
|
||||
MS_ASSERT(input != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
output->set_data_type(TypeId::kNumberTypeFloat32);
|
||||
output->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto input_shape = input->shape();
|
||||
input_shape.pop_back();
|
||||
outputs_.front()->set_shape(input_shape);
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,43 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LITE_MINDSPORE_LITE_C_OPS_FFT_REAL_H_
|
||||
#define LITE_MINDSPORE_LITE_C_OPS_FFT_REAL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <cmath>
|
||||
#include "src/ops/primitive_c.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class FftReal : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(FftReal, PrimitiveC);
|
||||
FftReal() = default;
|
||||
explicit FftReal(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#else
|
||||
FftReal() = default;
|
||||
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_FFT_REAL_H_
|
|
@ -0,0 +1,83 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/mfcc.h"
|
||||
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
#include "src/ops/ops_register.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
float Mfcc::GetFreqUpperLimit() const { return this->primitive_->value.AsMfcc()->freqUpperLimit; }
|
||||
float Mfcc::GetFreqLowerLimit() const { return this->primitive_->value.AsMfcc()->freqLowerLimit; }
|
||||
int Mfcc::GetFilterBankChannelNum() const { return this->primitive_->value.AsMfcc()->filterBankChannelNum; }
|
||||
int Mfcc::GetDctCoeffNum() const { return this->primitive_->value.AsMfcc()->dctCoeffNum; }
|
||||
|
||||
#else
|
||||
float Mfcc::GetFreqUpperLimit() const { return this->primitive_->value_as_Mfcc()->freqUpperLimit(); }
|
||||
float Mfcc::GetFreqLowerLimit() const { return this->primitive_->value_as_Mfcc()->freqLowerLimit(); }
|
||||
int Mfcc::GetFilterBankChannelNum() const { return this->primitive_->value_as_Mfcc()->filterBankChannelNum(); }
|
||||
int Mfcc::GetDctCoeffNum() const { return this->primitive_->value_as_Mfcc()->dctCoeffNum(); }
|
||||
int Mfcc::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_Mfcc();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_Add return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateMfcc(*fbb, attr->freqUpperLimit(), attr->freqLowerLimit(),
|
||||
attr->filterBankChannelNum(), attr->dctCoeffNum());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Mfcc, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *MfccCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Mfcc>(primitive); }
|
||||
Registry MfccRegistry(schema::PrimitiveType_Mfcc, MfccCreator);
|
||||
#endif
|
||||
int Mfcc::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
auto input = inputs_.front();
|
||||
MS_ASSERT(input != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto input_shape = input->shape();
|
||||
if (input_shape.size() != 3) {
|
||||
MS_LOG(ERROR) << "first input shape is error, which need to be 3 dimensions, but the dimension is "
|
||||
<< input_shape.size();
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (inputs_[1]->ElementsNum() != 1) {
|
||||
MS_LOG(ERROR) << "second input element num is error, which need only a value, but the number is "
|
||||
<< inputs_[1]->ElementsNum();
|
||||
return RET_ERROR;
|
||||
}
|
||||
std::vector<int> output_shape(3);
|
||||
output_shape[0] = input_shape[0];
|
||||
output_shape[1] = input_shape[1];
|
||||
output_shape[2] = GetDctCoeffNum();
|
||||
outputs_.front()->set_shape(output_shape);
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LITE_MINDSPORE_LITE_C_OPS_MFCC_H_
|
||||
#define LITE_MINDSPORE_LITE_C_OPS_MFCC_H_
|
||||
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <cmath>
|
||||
#include "src/ops/primitive_c.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class Mfcc : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(Mfcc, PrimitiveC);
|
||||
Mfcc() = default;
|
||||
explicit Mfcc(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetFreqUpperLimit(float freq_upper_limit) {
|
||||
this->primitive_->value.AsMfcc()->freqUpperLimit = freq_upper_limit;
|
||||
}
|
||||
void SetFreqLowerLimit(float freq_lower_limit) {
|
||||
this->primitive_->value.AsMfcc()->freqLowerLimit = freq_lower_limit;
|
||||
}
|
||||
void SetFilterBankChannelNum(int filter_bank_channel_num) {
|
||||
this->primitive_->value.AsMfcc()->filterBankChannelNum = filter_bank_channel_num;
|
||||
}
|
||||
void SetDctCoeffNum(int dct_coeff_num) { this->primitive_->value.AsMfcc()->dctCoeffNum = dct_coeff_num; }
|
||||
#else
|
||||
Mfcc() = default;
|
||||
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
float GetFreqUpperLimit() const;
|
||||
float GetFreqLowerLimit() const;
|
||||
int GetFilterBankChannelNum() const;
|
||||
int GetDctCoeffNum() const;
|
||||
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_MFCC_H_
|
|
@ -137,6 +137,11 @@
|
|||
#include "src/ops/upsample.h"
|
||||
#include "src/ops/layer_norm.h"
|
||||
#include "src/ops/non_max_suppression.h"
|
||||
#include "src/ops/rfft.h"
|
||||
#include "src/ops/fft_real.h"
|
||||
#include "src/ops/fft_imag.h"
|
||||
#include "src/ops/audio_spectrogram.h"
|
||||
#include "src/ops/mfcc.h"
|
||||
#include "src/ops/identity.h"
|
||||
|
||||
#ifdef SUPPORT_TRAIN
|
||||
|
@ -775,6 +780,16 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
|
|||
return new NonMaxSuppression(primitive);
|
||||
case schema::PrimitiveType_Identity:
|
||||
return new Identity(primitive);
|
||||
case schema::PrimitiveType_Rfft:
|
||||
return new Rfft(primitive);
|
||||
case schema::PrimitiveType_FftReal:
|
||||
return new FftReal(primitive);
|
||||
case schema::PrimitiveType_FftImag:
|
||||
return new FftImag(primitive);
|
||||
case schema::PrimitiveType_AudioSpectrogram:
|
||||
return new AudioSpectrogram(primitive);
|
||||
case schema::PrimitiveType_Mfcc:
|
||||
return new Mfcc(primitive);
|
||||
|
||||
#ifdef SUPPORT_TRAIN
|
||||
case schema::PrimitiveType_ActivationGrad:
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/rfft.h"
|
||||
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
#include "src/ops/ops_register.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
int Rfft::GetFftLength() const { return this->primitive_->value.AsRfft()->fftLength; }
|
||||
|
||||
void Rfft::SetFftLength(int fft_length) { this->primitive_->value.AsRfft()->fftLength = fft_length; }
|
||||
|
||||
#else
|
||||
int Rfft::GetFftLength() const { return this->primitive_->value_as_Rfft()->fftLength(); }
|
||||
int Rfft::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_Rfft();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_Add return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateRfft(*fbb, attr->fftLength());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Rfft, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *RfftCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Rfft>(primitive); }
|
||||
Registry RfftRegistry(schema::PrimitiveType_Rfft, RfftCreator);
|
||||
#endif
|
||||
int Rfft::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
auto input = inputs_.front();
|
||||
MS_ASSERT(input != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
output->set_data_type(TypeId::kNumberTypeComplex64);
|
||||
output->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto input_shape = input->shape();
|
||||
input_shape[input_shape.size() - 1] = GetFftLength() / 2 + 1;
|
||||
input_shape.push_back(2);
|
||||
outputs_.front()->set_shape(input_shape);
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,45 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LITE_MINDSPORE_LITE_C_OPS_RFFT_H_
|
||||
#define LITE_MINDSPORE_LITE_C_OPS_RFFT_H_
|
||||
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <cmath>
|
||||
#include "src/ops/primitive_c.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class Rfft : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(Rfft, PrimitiveC);
|
||||
Rfft() = default;
|
||||
explicit Rfft(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetFftLength(int fft_length);
|
||||
#else
|
||||
Rfft() = default;
|
||||
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int GetFftLength() const;
|
||||
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_RFFT_H_
|
|
@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/caffe/caffe_elu_parser.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS CaffeEluParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight,
|
||||
schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) {
|
||||
MS_LOG(DEBUG) << "parse CaffeEluParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::EluT> attr = std::make_unique<schema::EluT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
if (proto.has_elu_param()) {
|
||||
const caffe::ELUParameter eluParameter = proto.elu_param();
|
||||
if (eluParameter.has_alpha()) {
|
||||
attr->alpha = eluParameter.alpha();
|
||||
}
|
||||
}
|
||||
|
||||
op->name = proto.name();
|
||||
op->primitive->value.type = schema::PrimitiveType_Elu;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
CaffeNodeRegistrar g_caffeEluParser("ELU", new CaffeEluParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_ELU_PARSER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_ELU_PARSER_H_
|
||||
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/caffe/caffe_node_parser.h"
|
||||
#include "tools/converter/parser/caffe/caffe_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class CaffeEluParser : public CaffeNodeParser {
|
||||
public:
|
||||
CaffeEluParser() : CaffeNodeParser("elu") {}
|
||||
|
||||
STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op,
|
||||
std::vector<schema::TensorT *> *weightVec) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_ELU_PARSER_H_
|
|
@ -243,7 +243,7 @@ STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, const caff
|
|||
auto status_node = nodeParser->Parse(layer, layerP, op.get(), &weightVec);
|
||||
if (status_node != RET_OK) {
|
||||
interrupt = true;
|
||||
if (status_node == RET_NOT_SUPPORT) {
|
||||
if (status_node == RET_NOT_FIND_OP) {
|
||||
NoSupportOp::GetInstance()->InsertOp(layer.type());
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Parse weight for " << layer.name() << " Failed!";
|
||||
|
|
|
@ -156,8 +156,8 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
|||
|
||||
if (attr->group != 1) {
|
||||
if (!ParseGroupDeConvolution(attr, op)) {
|
||||
MS_LOG(ERROR) << "Convert DeConvolution to DeDepthwise failed";
|
||||
return RET_ERROR;
|
||||
MS_LOG(ERROR) << "Convert DeConvolution to DeDepthwise failed, generalized group deconv hasn't support";
|
||||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
} else {
|
||||
op->primitive->value.type = schema::PrimitiveType_DeConv2D;
|
||||
|
|
|
@ -522,6 +522,7 @@ STATUS OnnxModelParser::ParseSubgraph(schema::CNodeT *dst_op, const onnx::NodePr
|
|||
dst_op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
schema::MetaGraphT *OnnxModelParser::ParseGraph(const onnx::GraphProto &onnx_graph, const QuantType &quantType) {
|
||||
TensorCache tensor_cache;
|
||||
// dst_graph->name = onnx_graph.name(); // this is not used
|
||||
|
@ -593,6 +594,7 @@ schema::MetaGraphT *OnnxModelParser::ParseGraph(const onnx::GraphProto &onnx_gra
|
|||
SetAllTensors(tensor_cache, dst_graph.get());
|
||||
return dst_graph.release();
|
||||
}
|
||||
|
||||
schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile,
|
||||
const QuantType &quantType) {
|
||||
int status = ValidateFileStr(modelFile, ".onnx");
|
||||
|
|
|
@ -72,11 +72,12 @@ STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu
|
|||
auto data_index = tflite_op->inputs[0];
|
||||
const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index];
|
||||
std::vector<int> params;
|
||||
if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW,
|
||||
¶ms) != RET_OK) {
|
||||
int status =
|
||||
getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "get padding params failed";
|
||||
return RET_ERROR;
|
||||
} else {
|
||||
} else if (status == RET_OK) {
|
||||
attr->padUp = params.at(0);
|
||||
attr->padDown = params.at(1);
|
||||
attr->padLeft = params.at(2);
|
||||
|
|
|
@ -73,11 +73,12 @@ STATUS TfliteDeConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
auto data_index = tflite_op->inputs[2];
|
||||
const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index];
|
||||
std::vector<int> params;
|
||||
if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW,
|
||||
¶ms) != RET_OK) {
|
||||
int status =
|
||||
getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "get padding params failed";
|
||||
return RET_ERROR;
|
||||
} else {
|
||||
} else if (status == RET_OK) {
|
||||
attr->padUp = params.at(0);
|
||||
attr->padDown = params.at(1);
|
||||
attr->padLeft = params.at(2);
|
||||
|
|
|
@ -79,11 +79,12 @@ STATUS TfliteDepthwiseConv2DParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
|
||||
// calculate pad params
|
||||
std::vector<int> params;
|
||||
if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW,
|
||||
¶ms) != RET_OK) {
|
||||
int status =
|
||||
getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "get padding params failed";
|
||||
return RET_ERROR;
|
||||
} else {
|
||||
} else if (status == RET_OK) {
|
||||
attr->padUp = params.at(0);
|
||||
attr->padDown = params.at(1);
|
||||
attr->padLeft = params.at(2);
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include <utility>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "tools/common/storage.h"
|
||||
#include "flatbuffers/flatbuffers.h"
|
||||
|
@ -102,11 +103,6 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit
|
|||
for (const auto &tflite_op : tflite_subgraph->operators) {
|
||||
auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code;
|
||||
auto op_type = GetMSOpType(tflite_op_type);
|
||||
if (op_type == "CUSTOM") {
|
||||
auto custom_type = (tflite_model->operator_codes[tflite_op->opcode_index])->custom_code;
|
||||
MS_LOG(ERROR) << "CUSTOM op is not supported, the type is " << custom_type;
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto op = std::make_unique<schema::CNodeT>();
|
||||
op->name = op_type + "-" + std::to_string(idx++);
|
||||
|
@ -122,7 +118,9 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit
|
|||
if (status == RET_OK) {
|
||||
status = node_parser->Parse(&tensorsInfo, tflite_op, tflite_model, op.get());
|
||||
if (status != RET_OK) {
|
||||
if (status == RET_NOT_SUPPORT) {
|
||||
if (status == RET_NOT_FIND_OP) {
|
||||
op_type =
|
||||
(op_type != "Custom" ? op_type : (tflite_model->operator_codes[tflite_op->opcode_index])->custom_code);
|
||||
NoSupportOp::GetInstance()->InsertOp(op_type);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "node " << op_type.c_str() << " parser failed";
|
||||
|
@ -141,6 +139,16 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit
|
|||
STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::MetaGraphT *sub_graph) {
|
||||
std::set<int> output_index;
|
||||
for (const auto &tflite_op : tflite_subgraph->operators) {
|
||||
for (size_t j = 0; j < tflite_op->outputs.size(); ++j) {
|
||||
int idx = tflite_op->outputs[j];
|
||||
if (idx < 0) {
|
||||
idx += tflite_subgraph->tensors.size();
|
||||
}
|
||||
output_index.insert(idx);
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < tensorsInfo.tensorsId.size(); i++) {
|
||||
auto idx = tensorsInfo.tensorsId[i];
|
||||
if (idx < 0) {
|
||||
|
@ -173,11 +181,16 @@ STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr<tflite::SubGraphT>
|
|||
return status;
|
||||
}
|
||||
}
|
||||
|
||||
// set tensor attr
|
||||
if (isInput || isConst) {
|
||||
tensor->nodeType = schema::NodeType::NodeType_ValueNode;
|
||||
} else {
|
||||
tensor->nodeType = schema::NodeType_Parameter;
|
||||
if (output_index.find(idx) == output_index.end() && tflite_tensor->shape[0] == 0) {
|
||||
tensor->nodeType = schema::NodeType::NodeType_ValueNode;
|
||||
} else {
|
||||
tensor->nodeType = schema::NodeType_Parameter;
|
||||
}
|
||||
}
|
||||
|
||||
// quant param
|
||||
|
@ -246,7 +259,6 @@ STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT *sub_graph)
|
|||
if (op->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) {
|
||||
auto attr = op->primitive->value.AsDepthwiseConv2D();
|
||||
if (attr->channelMultiplier > 1) {
|
||||
std::unique_ptr<schema::Conv2DT> conv_attr = std::make_unique<schema::Conv2DT>();
|
||||
// get channel attr
|
||||
if (op->inputIndex.empty()) {
|
||||
MS_LOG(ERROR) << "the input of DepthwiseConv2D is null";
|
||||
|
@ -263,7 +275,11 @@ STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT *sub_graph)
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
auto data_shape = data_tensor->dims;
|
||||
|
||||
if (data_shape.empty()) {
|
||||
MS_LOG(DEBUG) << "the tensor's shape is dynamic, which obtain only when running";
|
||||
return RET_NO_CHANGE;
|
||||
}
|
||||
std::unique_ptr<schema::Conv2DT> conv_attr = std::make_unique<schema::Conv2DT>();
|
||||
if (data_shape[3] == 1) {
|
||||
conv_attr->channelIn = data_shape[3];
|
||||
conv_attr->channelOut = conv_attr->channelIn * attr->channelMultiplier;
|
||||
|
@ -372,7 +388,7 @@ schema::MetaGraphT *TfliteModelParser::ParseToFb(const std::string &model_file,
|
|||
|
||||
// update for depthwiseConv
|
||||
status = ConvertGroupDepthwiseOp(meta_graph.get());
|
||||
if (status != RET_OK) {
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "convert group depthwise conv failed";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
|
|
|
@ -86,6 +86,10 @@ class TfliteNodeParser {
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
auto data_ptr = buf_data->data.data();
|
||||
if (data_ptr == nullptr) {
|
||||
MS_LOG(DEBUG) << "data is not a constant";
|
||||
return RET_NO_CHANGE;
|
||||
}
|
||||
switch (tflite_tensors[tensor_index]->type) {
|
||||
case tflite::TensorType_UINT8: {
|
||||
for (int i = 0; i < count; i++) {
|
||||
|
|
|
@ -71,11 +71,11 @@ STATUS TflitePadParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique
|
|||
break;
|
||||
default:
|
||||
MS_LOG(ERROR) << "paddingmode:" << tflite_attr->mode << " don't support";
|
||||
return RET_INVALID_OP_ATTR;
|
||||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "this pad:" << node_name << " hasn't been supported";
|
||||
return RET_NOT_SUPPORT;
|
||||
return RET_NOT_FIND_OP;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Pad;
|
||||
|
|
|
@ -71,11 +71,12 @@ STATUS TflitePoolingParser::Parse(TfliteTensorsInfo *tensors_info, const std::un
|
|||
auto data_index = tflite_op->inputs[0];
|
||||
const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index];
|
||||
std::vector<int> params;
|
||||
if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->windowH, attr->windowW,
|
||||
¶ms) != RET_OK) {
|
||||
int status =
|
||||
getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->windowH, attr->windowW, ¶ms);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "get padding params failed";
|
||||
return RET_ERROR;
|
||||
} else {
|
||||
} else if (status == RET_OK) {
|
||||
attr->padUp = params.at(0);
|
||||
attr->padDown = params.at(1);
|
||||
attr->padLeft = params.at(2);
|
||||
|
|
|
@ -41,15 +41,31 @@ STATUS TfliteRangeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq
|
|||
}
|
||||
|
||||
attr->dType = 0;
|
||||
// attr->start
|
||||
// attr->limit
|
||||
// attr->delta
|
||||
|
||||
std::vector<int> limit;
|
||||
std::vector<int> delta;
|
||||
int status = GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, limit);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "range -> limit get failed";
|
||||
return RET_ERROR;
|
||||
} else if (status == RET_OK) {
|
||||
status = GetTfliteData(tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, delta);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "stridedSlice -> end get failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
if (status == RET_OK) {
|
||||
attr->limit = limit.front();
|
||||
attr->delta = delta.front();
|
||||
}
|
||||
op->primitive->value.type = schema::PrimitiveType_Range;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
int input_num = status == RET_OK ? 1 : 3;
|
||||
for (int i = 0; i < input_num; ++i) {
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
}
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
|
|
|
@ -69,7 +69,7 @@ STATUS TfliteReduceParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
} else if (std::strcmp(node_name, "ReduceAny") == 0) {
|
||||
// attr->mode;
|
||||
MS_LOG(ERROR) << "ms-lite haven't supported REDUCE_ANY now";
|
||||
return RET_NOT_FIND_OP;
|
||||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->axes)) {
|
||||
|
|
|
@ -67,14 +67,15 @@ STATUS TfliteSplitParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq
|
|||
return RET_ERROR;
|
||||
}
|
||||
attr->splitDim = axis;
|
||||
if (tensor_shape[axis] % num_splits != 0) {
|
||||
if (tensor_shape[axis] % num_splits != 0 && tensor_shape[axis] / num_splits != 0) {
|
||||
MS_LOG(ERROR) << "num_splits can't divide tensor's length at axis " << axis;
|
||||
return RET_ERROR;
|
||||
}
|
||||
attr->numberSplit = num_splits;
|
||||
|
||||
for (int i = 0; i < num_splits; i++) {
|
||||
attr->sizeSplits.push_back(tensor_shape[axis] / num_splits);
|
||||
if (tensor_shape[axis] / num_splits != 0) {
|
||||
for (int i = 0; i < num_splits; i++) {
|
||||
attr->sizeSplits.push_back(tensor_shape[axis] / num_splits);
|
||||
}
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Split;
|
||||
|
|
|
@ -52,17 +52,24 @@ STATUS TfliteStridedSliceParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
attr->newAxisMask = tflite_attr->new_axis_mask;
|
||||
attr->shrinkAxisMask = tflite_attr->shrink_axis_mask;
|
||||
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->begin)) {
|
||||
int status =
|
||||
GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->begin);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "stridedSlice -> begin get failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (GetTfliteData(tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->end)) {
|
||||
MS_LOG(ERROR) << "stridedSlice -> end get failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (GetTfliteData(tflite_op->inputs[3], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->stride)) {
|
||||
MS_LOG(ERROR) << "stridedSlice -> stride get failed";
|
||||
return RET_ERROR;
|
||||
} else if (status == RET_OK) {
|
||||
status = GetTfliteData(tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->end);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "stridedSlice -> end get failed";
|
||||
return RET_ERROR;
|
||||
} else if (status == RET_OK) {
|
||||
status =
|
||||
GetTfliteData(tflite_op->inputs[3], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->stride);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "stridedSlice -> stride get failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
}
|
||||
attr->isScale.assign(tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]->shape.begin(),
|
||||
tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]->shape.end());
|
||||
|
@ -70,8 +77,11 @@ STATUS TfliteStridedSliceParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
op->primitive->value.type = schema::PrimitiveType_StridedSlice;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
int input_num = status == RET_OK ? 1 : 4;
|
||||
for (int i = 0; i < input_num; ++i) {
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
}
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
|
|
|
@ -198,7 +198,10 @@ STATUS getPaddingParam(const std::unique_ptr<tflite::TensorT> &tensor, schema::P
|
|||
MS_LOG(ERROR) << "the input tensor is null";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
if (tensor->shape.empty()) {
|
||||
MS_LOG(DEBUG) << "the tensor's shape is dynamic, which obtain nly when running.";
|
||||
return RET_NO_CHANGE;
|
||||
}
|
||||
int padUp = 0;
|
||||
int padDown = 0;
|
||||
int padLeft = 0;
|
||||
|
|
Loading…
Reference in New Issue