commit
6f0eb82333
|
@ -37,13 +37,13 @@ class MSOpsRegistry {
|
|||
}
|
||||
void InsertPrimitiveTMap(const std::string &name, PrimitiveTCreator creator) {
|
||||
std::string lower_name = name;
|
||||
std::transform(name.begin(), name.end(), lower_name.begin(), ::tolower);
|
||||
(void)std::transform(name.begin(), name.end(), lower_name.begin(), ::tolower);
|
||||
primitive_creators[lower_name] = creator;
|
||||
}
|
||||
PrimitiveTCreator GetPrimitiveCreator(const std::string &name) {
|
||||
std::string lower_name = name;
|
||||
std::transform(name.begin(), name.end(), lower_name.begin(), ::tolower);
|
||||
lower_name.erase(std::remove(lower_name.begin(), lower_name.end(), '_'), lower_name.end());
|
||||
(void)std::transform(name.begin(), name.end(), lower_name.begin(), ::tolower);
|
||||
(void)lower_name.erase(std::remove(lower_name.begin(), lower_name.end(), '_'), lower_name.end());
|
||||
if (primitive_creators.find(lower_name) != primitive_creators.end()) {
|
||||
return primitive_creators[lower_name];
|
||||
} else {
|
||||
|
|
|
@ -22,7 +22,8 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
constexpr size_t INITIAL_SIZE = 1024;
|
||||
const schema::Primitive *ConvertToPrimitive(schema::PrimitiveT *primitive_t, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
const schema::Primitive *ConvertToPrimitive(const schema::PrimitiveT *primitive_t,
|
||||
flatbuffers::FlatBufferBuilder *fbb) {
|
||||
if (primitive_t == nullptr || fbb == nullptr) {
|
||||
MS_LOG(ERROR) << "primitiveT or fbb is nullptr.";
|
||||
return nullptr;
|
||||
|
@ -33,12 +34,13 @@ const schema::Primitive *ConvertToPrimitive(schema::PrimitiveT *primitive_t, fla
|
|||
return flatbuffers::GetRoot<schema::Primitive>(prim_buf);
|
||||
}
|
||||
|
||||
OpParameter *GetOpParameter(schema::PrimitiveT *primitive_t) {
|
||||
OpParameter *GetOpParameter(const schema::PrimitiveT *primitive_t) {
|
||||
flatbuffers::FlatBufferBuilder fbb(INITIAL_SIZE);
|
||||
auto primitive = ConvertToPrimitive(primitive_t, &fbb);
|
||||
fbb.Clear();
|
||||
auto prim_type = GetPrimitiveType(primitive, SCHEMA_VERSION::SCHEMA_CUR);
|
||||
auto parame_gen = PopulateRegistry::GetInstance()->GetParameterCreator(prim_type, SCHEMA_VERSION::SCHEMA_CUR);
|
||||
auto prim_type = GetPrimitiveType(primitive, static_cast<int>(SCHEMA_VERSION::SCHEMA_CUR));
|
||||
auto parame_gen =
|
||||
PopulateRegistry::GetInstance()->GetParameterCreator(prim_type, static_cast<int>(SCHEMA_VERSION::SCHEMA_CUR));
|
||||
if (parame_gen == nullptr) {
|
||||
MS_LOG(ERROR) << "parameter generator is nullptr.";
|
||||
return nullptr;
|
||||
|
@ -46,7 +48,7 @@ OpParameter *GetOpParameter(schema::PrimitiveT *primitive_t) {
|
|||
auto parameter = parame_gen(primitive);
|
||||
if (parameter == nullptr) {
|
||||
MS_LOG(ERROR) << "PopulateParameter return nullptr, type: "
|
||||
<< GetPrimitiveTypeName(primitive, SCHEMA_VERSION::SCHEMA_CUR);
|
||||
<< GetPrimitiveTypeName(primitive, static_cast<int>(SCHEMA_VERSION::SCHEMA_CUR));
|
||||
}
|
||||
return parameter;
|
||||
}
|
||||
|
|
|
@ -24,8 +24,8 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
const schema::Primitive *ConvertToPrimitive(schema::PrimitiveT *primitive_t, flatbuffers::FlatBufferBuilder *fbb);
|
||||
OpParameter *GetOpParameter(schema::PrimitiveT *primitive_t);
|
||||
const schema::Primitive *ConvertToPrimitive(const schema::PrimitiveT *primitive_t, flatbuffers::FlatBufferBuilder *fbb);
|
||||
OpParameter *GetOpParameter(const schema::PrimitiveT *primitive_t);
|
||||
std::unique_ptr<schema::PrimitiveT> GetPrimitiveT(const std::shared_ptr<mindspore::ops::BaseOperator> &op);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue