forked from mindspore-Ecosystem/mindspore
!46174 [MS][LITE] fix shape fusion register bug && get prim type name bug
Merge pull request !46174 from jianghui58/codex_fuzz_master
This commit is contained in:
commit
8ce39575c7
|
@ -119,7 +119,7 @@
|
|||
#include "nnacl/infer/scatter_nd_update_infer.h"
|
||||
#include "nnacl/infer/select_infer.h"
|
||||
#include "nnacl/infer/sgd_infer.h"
|
||||
#ifdef MSLITE_ENABLE_RUNTIME_PASS
|
||||
#ifndef RUNTIME_PASS_CLIP
|
||||
#include "nnacl/infer/shape_fusion_infer.h"
|
||||
#endif
|
||||
#include "nnacl/infer/shape_infer.h"
|
||||
|
@ -400,7 +400,7 @@ void RegAllInferFunc5() {
|
|||
// fused operators.
|
||||
g_inner_op_infer_func[PrimType_Inner_GltextureToOpencl - PrimType_InnerOpMin] = NULL;
|
||||
g_inner_op_infer_func[PrimType_Inner_Identity - PrimType_InnerOpMin] = NULL;
|
||||
#ifdef MSLITE_ENABLE_RUNTIME_PASS
|
||||
#ifndef RUNTIME_PASS_CLIP
|
||||
g_inner_op_infer_func[PrimType_Inner_ShapeFusion - PrimType_InnerOpMin] = ShapeFusionInferShape;
|
||||
#endif
|
||||
g_inner_op_infer_func[PrimType_Inner_ToFormat - PrimType_InnerOpMin] = NULL;
|
||||
|
|
|
@ -22,11 +22,15 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
static std::set<schema::PrimitiveType> tensor_list_ops = {
|
||||
static std::set<schema::PrimitiveType> kTensorListOps = {
|
||||
schema::PrimitiveType_TensorListFromTensor, schema::PrimitiveType_TensorListGetItem,
|
||||
schema::PrimitiveType_TensorListReserve, schema::PrimitiveType_TensorListSetItem,
|
||||
schema::PrimitiveType_TensorListStack};
|
||||
|
||||
static const char *const kInnerOpNames[6] = {
|
||||
"Inner_ToFormat", "Inner_GltextureToOpencl", "Inner_Identity",
|
||||
"Inner_ShapeFusion", "Inner_GraphKernel", "Inner_SplitReduceConcatFusion",
|
||||
};
|
||||
int GetPrimitiveType(const void *primitive, int schema_version) {
|
||||
if (primitive == nullptr) {
|
||||
return -1;
|
||||
|
@ -42,7 +46,14 @@ const char *GetPrimitiveTypeName(const void *primitive, int schema_version) {
|
|||
}
|
||||
|
||||
const char *PrimitiveCurVersionTypeName(int type) {
|
||||
return schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(type));
|
||||
if (type >= static_cast<int>(schema::PrimitiveType_MIN) && type < static_cast<int>(schema::PrimitiveType_MAX)) {
|
||||
return schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(type));
|
||||
} else if (type >= static_cast<int>(schema::PrimitiveType_MAX)) {
|
||||
if (type >= PrimType_InnerOpMin && type < PrimType_InnerOpMax) {
|
||||
return kInnerOpNames[type - PrimType_InnerOpMin];
|
||||
}
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
int GenPrimVersionKey(int primitive_type, int schema_version) { return primitive_type * 1000 + schema_version; }
|
||||
|
@ -90,8 +101,8 @@ bool IsCustomNode(const void *primitive, int schema_version) {
|
|||
bool IsTensorListNode(const void *primitive, int schema_version) {
|
||||
MS_CHECK_TRUE_MSG(primitive != nullptr, false, "primtive cannot be nullptr");
|
||||
if (schema_version == SCHEMA_CUR) {
|
||||
if (tensor_list_ops.find(reinterpret_cast<const schema::Primitive *>(primitive)->value_type()) !=
|
||||
tensor_list_ops.end()) {
|
||||
if (kTensorListOps.find(reinterpret_cast<const schema::Primitive *>(primitive)->value_type()) !=
|
||||
kTensorListOps.end()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue