!46174 [MS][LITE] fix shape fusion register bug && get prim type name bug

Merge pull request !46174 from jianghui58/codex_fuzz_master
This commit is contained in:
i-robot 2022-11-29 10:35:09 +00:00 committed by Gitee
commit 8ce39575c7
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 17 additions and 6 deletions

View File

@ -119,7 +119,7 @@
#include "nnacl/infer/scatter_nd_update_infer.h"
#include "nnacl/infer/select_infer.h"
#include "nnacl/infer/sgd_infer.h"
#ifdef MSLITE_ENABLE_RUNTIME_PASS
#ifndef RUNTIME_PASS_CLIP
#include "nnacl/infer/shape_fusion_infer.h"
#endif
#include "nnacl/infer/shape_infer.h"
@ -400,7 +400,7 @@ void RegAllInferFunc5() {
// fused operators.
g_inner_op_infer_func[PrimType_Inner_GltextureToOpencl - PrimType_InnerOpMin] = NULL;
g_inner_op_infer_func[PrimType_Inner_Identity - PrimType_InnerOpMin] = NULL;
#ifdef MSLITE_ENABLE_RUNTIME_PASS
#ifndef RUNTIME_PASS_CLIP
g_inner_op_infer_func[PrimType_Inner_ShapeFusion - PrimType_InnerOpMin] = ShapeFusionInferShape;
#endif
g_inner_op_infer_func[PrimType_Inner_ToFormat - PrimType_InnerOpMin] = NULL;

View File

@ -22,11 +22,15 @@
namespace mindspore {
namespace lite {
static std::set<schema::PrimitiveType> tensor_list_ops = {
static std::set<schema::PrimitiveType> kTensorListOps = {
schema::PrimitiveType_TensorListFromTensor, schema::PrimitiveType_TensorListGetItem,
schema::PrimitiveType_TensorListReserve, schema::PrimitiveType_TensorListSetItem,
schema::PrimitiveType_TensorListStack};
static const char *const kInnerOpNames[6] = {
"Inner_ToFormat", "Inner_GltextureToOpencl", "Inner_Identity",
"Inner_ShapeFusion", "Inner_GraphKernel", "Inner_SplitReduceConcatFusion",
};
int GetPrimitiveType(const void *primitive, int schema_version) {
if (primitive == nullptr) {
return -1;
@ -42,7 +46,14 @@ const char *GetPrimitiveTypeName(const void *primitive, int schema_version) {
}
const char *PrimitiveCurVersionTypeName(int type) {
return schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(type));
if (type >= static_cast<int>(schema::PrimitiveType_MIN) && type < static_cast<int>(schema::PrimitiveType_MAX)) {
return schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(type));
} else if (type >= static_cast<int>(schema::PrimitiveType_MAX)) {
if (type >= PrimType_InnerOpMin && type < PrimType_InnerOpMax) {
return kInnerOpNames[type - PrimType_InnerOpMin];
}
}
return "";
}
int GenPrimVersionKey(int primitive_type, int schema_version) { return primitive_type * 1000 + schema_version; }
@ -90,8 +101,8 @@ bool IsCustomNode(const void *primitive, int schema_version) {
bool IsTensorListNode(const void *primitive, int schema_version) {
MS_CHECK_TRUE_MSG(primitive != nullptr, false, "primtive cannot be nullptr");
if (schema_version == SCHEMA_CUR) {
if (tensor_list_ops.find(reinterpret_cast<const schema::Primitive *>(primitive)->value_type()) !=
tensor_list_ops.end()) {
if (kTensorListOps.find(reinterpret_cast<const schema::Primitive *>(primitive)->value_type()) !=
kTensorListOps.end()) {
return true;
}
}