[MS][LITE] solve static check

[MS][LITE] add
This commit is contained in:
cjh9368 2021-04-26 15:10:31 +08:00
parent d328fdcf61
commit c91b7d15fa
11 changed files with 32 additions and 4 deletions

View File

@ -24,6 +24,7 @@
namespace mindspore {
namespace lite {
int GetPrimitiveType(const void *primitive) {
MS_ASSERT(primitive != nullptr);
if (primitive == nullptr) {
return -1;
}
@ -51,6 +52,7 @@ const char *PrimitiveCurVersionTypeName(int type) {
int GenPrimVersionKey(int primitive_type, int schema_version) { return primitive_type * 1000 + schema_version; }
bool IsPartialNode(const void *primitive) {
MS_ASSERT(primitive != nullptr);
int schema_version = VersionManager::GetInstance()->GetSchemaVersion();
if (schema_version == SCHEMA_CUR) {
return reinterpret_cast<const schema::Primitive *>(primitive)->value_type() == schema::PrimitiveType_PartialFusion;
@ -65,9 +67,11 @@ bool IsPartialNode(const void *primitive) {
}
int GetPartialGraphIndex(const void *primitive) {
MS_ASSERT(primitive != nullptr);
int index = -1;
int schema_version = VersionManager::GetInstance()->GetSchemaVersion();
if (schema_version == SCHEMA_CUR) {
MS_ASSERT(static_cast<const schema::Primitive *>(primitive)->value_as_PartialFusion() != nullptr);
index = static_cast<const schema::Primitive *>(primitive)->value_as_PartialFusion()->sub_graph_index();
}
#ifdef ENABLE_V0
@ -79,6 +83,7 @@ int GetPartialGraphIndex(const void *primitive) {
}
bool IsWhileNode(const void *primitive) {
MS_ASSERT(primitive != nullptr);
int schema_version = VersionManager::GetInstance()->GetSchemaVersion();
if (schema_version == SCHEMA_CUR) {
return reinterpret_cast<const schema::Primitive *>(primitive)->value_type() == schema::PrimitiveType_While;
@ -92,13 +97,16 @@ bool IsWhileNode(const void *primitive) {
}
int GetWhileBodySubgraphIndex(const void *primitive) {
MS_ASSERT(primitive != nullptr);
int index = -1;
int schema_version = VersionManager::GetInstance()->GetSchemaVersion();
if (schema_version == SCHEMA_CUR) {
MS_ASSERT(static_cast<const schema::Primitive *>(primitive)->value_as_While() != nullptr);
index = reinterpret_cast<const schema::Primitive *>(primitive)->value_as_While()->body_subgraph_index();
}
#ifdef ENABLE_V0
if (schema_version == SCHEMA_V0) {
MS_ASSERT(static_cast<const schema::Primitive *>(primitive)->value_as_While() != nullptr);
index = reinterpret_cast<const schema::v0::Primitive *>(primitive)->value_as_While()->bodySubgraphIndex();
}
#endif
@ -106,13 +114,16 @@ int GetWhileBodySubgraphIndex(const void *primitive) {
}
int GetWhileCondSubgraphIndex(const void *primitive) {
MS_ASSERT(primitive != nullptr);
int index = -1;
int schema_version = VersionManager::GetInstance()->GetSchemaVersion();
if (schema_version == SCHEMA_CUR) {
MS_ASSERT(static_cast<const schema::Primitive *>(primitive)->value_as_While() != nullptr);
index = reinterpret_cast<const schema::Primitive *>(primitive)->value_as_While()->cond_subgraph_index();
}
#ifdef ENABLE_V0
if (schema_version == SCHEMA_V0) {
MS_ASSERT(static_cast<const schema::Primitive *>(primitive)->value_as_While() != nullptr);
index = reinterpret_cast<const schema::v0::Primitive *>(primitive)->value_as_While()->condSubgraphIndex();
}
#endif

View File

@ -128,7 +128,7 @@ int TensorList2TensorListC(TensorList *src, TensorListC *dst) {
return NNACL_OK;
}
void TensorListC2TensorList(TensorListC *src, TensorList *dst) {
int TensorListC2TensorList(TensorListC *src, TensorList *dst) {
dst->set_data_type(static_cast<TypeId>(src->data_type_));
dst->set_format(static_cast<schema::Format>(src->format_));
dst->set_shape(std::vector<int>(1, src->element_num_));
@ -136,11 +136,17 @@ void TensorListC2TensorList(TensorListC *src, TensorList *dst) {
// Set Tensors
for (size_t i = 0; i < src->element_num_; i++) {
if (dst->GetTensor(i) == nullptr) {
MS_LOG(ERROR) << "Tensor i is null ptr";
return RET_NULL_PTR;
}
TensorC2Tensor(&src->tensors_[i], dst->GetTensor(i));
}
dst->set_element_shape(std::vector<int>(src->element_shape_, src->element_shape_ + src->element_shape_size_));
dst->set_max_elements_num(src->max_elements_num_);
return RET_OK;
}
int GenerateMergeSwitchOutTensorC(const std::vector<lite::Tensor *> &inputs, std::vector<lite::Tensor *> *outputs,
@ -189,6 +195,7 @@ int GenerateInTensorC(const OpParameter *const parameter, const std::vector<lite
memset(tensor_list_c, 0, sizeof(TensorListC));
ret = TensorList2TensorListC(tensor_list, tensor_list_c);
if (ret != RET_OK) {
free(tensor_list_c);
return NNACL_ERR;
}
in_tensor_c->push_back(reinterpret_cast<TensorC *>(tensor_list_c));

View File

@ -31,7 +31,7 @@ void FreeTensorListC(TensorListC *tensorListC);
void Tensor2TensorC(Tensor *src, TensorC *dst);
void TensorC2Tensor(TensorC *src, Tensor *dst);
int TensorList2TensorListC(TensorList *src, TensorListC *dst);
void TensorListC2TensorList(TensorListC *src, TensorList *dst);
int TensorListC2TensorList(TensorListC *src, TensorList *dst);
int GenerateMergeSwitchOutTensorC(const std::vector<lite::Tensor *> &inputs, std::vector<lite::Tensor *> *outputs,
std::vector<TensorC *> *out_tensor_c);
int GenerateInTensorC(const OpParameter *const parameter, const std::vector<lite::Tensor *> &inputs,

View File

@ -140,7 +140,7 @@ class MSTensor::Impl {
virtual bool IsDevice() const { return false; }
tensor::MSTensor *lite_tensor() { return lite_tensor_; }
tensor::MSTensor *lite_tensor() const { return lite_tensor_; }
Status set_lite_tensor(tensor::MSTensor *tensor) {
if (tensor == nullptr) {

View File

@ -48,6 +48,7 @@ class RegisterKernelInterface {
public:
static RegisterKernelInterface *Instance();
int Reg(const std::string &vendor, const int op_type, KernelInterfaceCreator creator);
virtual ~RegisterKernelInterface() = default;
private:
RegisterKernelInterface() = default;

View File

@ -31,6 +31,7 @@ class KernelInterfaceRegistry {
}
int Reg(const std::string &vendor, const int &op_type, kernel::KernelInterfaceCreator creator);
virtual ~KernelInterfaceRegistry() = default;
private:
KernelInterfaceRegistry() = default;

View File

@ -37,6 +37,8 @@ class SchemaRegisterImpl {
GetSchemaDef GetPrimTypeGenFunc() const { return prim_type_gen_; }
virtual ~SchemaRegisterImpl() = default;
private:
std::vector<GetSchemaDef> op_def_funcs_;
GetSchemaDef prim_type_gen_;
@ -45,11 +47,13 @@ class SchemaRegisterImpl {
class SchemaOpRegister {
public:
explicit SchemaOpRegister(GetSchemaDef func) { SchemaRegisterImpl::Instance()->OpPush(func); }
virtual ~SchemaOpRegister() = default;
};
class PrimitiveTypeRegister {
public:
explicit PrimitiveTypeRegister(GetSchemaDef func) { SchemaRegisterImpl::Instance()->SetPrimTypeGenFunc(func); }
virtual ~PrimitiveTypeRegister() = default;
};
} // namespace mindspore::lite::ops

View File

@ -31,6 +31,7 @@ class NPUInsertTransformPass : public NPUBasePass {
name_ = "NPUInsertTransformPass";
}
virtual ~NPUInsertTransformPass() = default;
int Run() override;
private:

View File

@ -36,6 +36,8 @@ class NPUTransformPass : public NPUBasePass {
name_ = "NPUTransformPass";
}
virtual ~NPUTransformPass() = default;
private:
int InsertPreNodes(kernel::LiteKernel *kernel, std::vector<kernel::LiteKernel *> *trans_kernels);

View File

@ -37,7 +37,7 @@ class GraphDefTransform {
virtual ~GraphDefTransform();
virtual int Transform(const converter::Flags &ctx);
void SetGraphDef(schema::MetaGraphT *dst_def);
inline schema::MetaGraphT *GetOutput() { return graph_defT_; }
inline schema::MetaGraphT *GetOutput() const { return graph_defT_; }
protected:
std::vector<schema::CNodeT *> GetGraphNodes();

View File

@ -57,6 +57,7 @@ class RegistryPrimitiveAdjust {
RegistryPrimitiveAdjust(const std::string &key, PrimitiveAdjustCreator creator) {
PrimitiveAdjustRegistry::GetInstance()->InsertPrimitiveAdjustMap(key, creator);
}
virtual ~RegistryPrimitiveAdjust() = default;
};
#define REGIST_PRIMITIVE_ADJUST(type, primitive_adjust_func) \