forked from mindspore-Ecosystem/mindspore
parent
d328fdcf61
commit
c91b7d15fa
|
@ -24,6 +24,7 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
int GetPrimitiveType(const void *primitive) {
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
if (primitive == nullptr) {
|
||||
return -1;
|
||||
}
|
||||
|
@ -51,6 +52,7 @@ const char *PrimitiveCurVersionTypeName(int type) {
|
|||
int GenPrimVersionKey(int primitive_type, int schema_version) { return primitive_type * 1000 + schema_version; }
|
||||
|
||||
bool IsPartialNode(const void *primitive) {
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
int schema_version = VersionManager::GetInstance()->GetSchemaVersion();
|
||||
if (schema_version == SCHEMA_CUR) {
|
||||
return reinterpret_cast<const schema::Primitive *>(primitive)->value_type() == schema::PrimitiveType_PartialFusion;
|
||||
|
@ -65,9 +67,11 @@ bool IsPartialNode(const void *primitive) {
|
|||
}
|
||||
|
||||
int GetPartialGraphIndex(const void *primitive) {
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
int index = -1;
|
||||
int schema_version = VersionManager::GetInstance()->GetSchemaVersion();
|
||||
if (schema_version == SCHEMA_CUR) {
|
||||
MS_ASSERT(static_cast<const schema::Primitive *>(primitive)->value_as_PartialFusion() != nullptr);
|
||||
index = static_cast<const schema::Primitive *>(primitive)->value_as_PartialFusion()->sub_graph_index();
|
||||
}
|
||||
#ifdef ENABLE_V0
|
||||
|
@ -79,6 +83,7 @@ int GetPartialGraphIndex(const void *primitive) {
|
|||
}
|
||||
|
||||
bool IsWhileNode(const void *primitive) {
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
int schema_version = VersionManager::GetInstance()->GetSchemaVersion();
|
||||
if (schema_version == SCHEMA_CUR) {
|
||||
return reinterpret_cast<const schema::Primitive *>(primitive)->value_type() == schema::PrimitiveType_While;
|
||||
|
@ -92,13 +97,16 @@ bool IsWhileNode(const void *primitive) {
|
|||
}
|
||||
|
||||
int GetWhileBodySubgraphIndex(const void *primitive) {
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
int index = -1;
|
||||
int schema_version = VersionManager::GetInstance()->GetSchemaVersion();
|
||||
if (schema_version == SCHEMA_CUR) {
|
||||
MS_ASSERT(static_cast<const schema::Primitive *>(primitive)->value_as_While() != nullptr);
|
||||
index = reinterpret_cast<const schema::Primitive *>(primitive)->value_as_While()->body_subgraph_index();
|
||||
}
|
||||
#ifdef ENABLE_V0
|
||||
if (schema_version == SCHEMA_V0) {
|
||||
MS_ASSERT(static_cast<const schema::Primitive *>(primitive)->value_as_While() != nullptr);
|
||||
index = reinterpret_cast<const schema::v0::Primitive *>(primitive)->value_as_While()->bodySubgraphIndex();
|
||||
}
|
||||
#endif
|
||||
|
@ -106,13 +114,16 @@ int GetWhileBodySubgraphIndex(const void *primitive) {
|
|||
}
|
||||
|
||||
int GetWhileCondSubgraphIndex(const void *primitive) {
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
int index = -1;
|
||||
int schema_version = VersionManager::GetInstance()->GetSchemaVersion();
|
||||
if (schema_version == SCHEMA_CUR) {
|
||||
MS_ASSERT(static_cast<const schema::Primitive *>(primitive)->value_as_While() != nullptr);
|
||||
index = reinterpret_cast<const schema::Primitive *>(primitive)->value_as_While()->cond_subgraph_index();
|
||||
}
|
||||
#ifdef ENABLE_V0
|
||||
if (schema_version == SCHEMA_V0) {
|
||||
MS_ASSERT(static_cast<const schema::Primitive *>(primitive)->value_as_While() != nullptr);
|
||||
index = reinterpret_cast<const schema::v0::Primitive *>(primitive)->value_as_While()->condSubgraphIndex();
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -128,7 +128,7 @@ int TensorList2TensorListC(TensorList *src, TensorListC *dst) {
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
void TensorListC2TensorList(TensorListC *src, TensorList *dst) {
|
||||
int TensorListC2TensorList(TensorListC *src, TensorList *dst) {
|
||||
dst->set_data_type(static_cast<TypeId>(src->data_type_));
|
||||
dst->set_format(static_cast<schema::Format>(src->format_));
|
||||
dst->set_shape(std::vector<int>(1, src->element_num_));
|
||||
|
@ -136,11 +136,17 @@ void TensorListC2TensorList(TensorListC *src, TensorList *dst) {
|
|||
|
||||
// Set Tensors
|
||||
for (size_t i = 0; i < src->element_num_; i++) {
|
||||
if (dst->GetTensor(i) == nullptr) {
|
||||
MS_LOG(ERROR) << "Tensor i is null ptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
TensorC2Tensor(&src->tensors_[i], dst->GetTensor(i));
|
||||
}
|
||||
|
||||
dst->set_element_shape(std::vector<int>(src->element_shape_, src->element_shape_ + src->element_shape_size_));
|
||||
dst->set_max_elements_num(src->max_elements_num_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int GenerateMergeSwitchOutTensorC(const std::vector<lite::Tensor *> &inputs, std::vector<lite::Tensor *> *outputs,
|
||||
|
@ -189,6 +195,7 @@ int GenerateInTensorC(const OpParameter *const parameter, const std::vector<lite
|
|||
memset(tensor_list_c, 0, sizeof(TensorListC));
|
||||
ret = TensorList2TensorListC(tensor_list, tensor_list_c);
|
||||
if (ret != RET_OK) {
|
||||
free(tensor_list_c);
|
||||
return NNACL_ERR;
|
||||
}
|
||||
in_tensor_c->push_back(reinterpret_cast<TensorC *>(tensor_list_c));
|
||||
|
|
|
@ -31,7 +31,7 @@ void FreeTensorListC(TensorListC *tensorListC);
|
|||
void Tensor2TensorC(Tensor *src, TensorC *dst);
|
||||
void TensorC2Tensor(TensorC *src, Tensor *dst);
|
||||
int TensorList2TensorListC(TensorList *src, TensorListC *dst);
|
||||
void TensorListC2TensorList(TensorListC *src, TensorList *dst);
|
||||
int TensorListC2TensorList(TensorListC *src, TensorList *dst);
|
||||
int GenerateMergeSwitchOutTensorC(const std::vector<lite::Tensor *> &inputs, std::vector<lite::Tensor *> *outputs,
|
||||
std::vector<TensorC *> *out_tensor_c);
|
||||
int GenerateInTensorC(const OpParameter *const parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
|
|
|
@ -140,7 +140,7 @@ class MSTensor::Impl {
|
|||
|
||||
virtual bool IsDevice() const { return false; }
|
||||
|
||||
tensor::MSTensor *lite_tensor() { return lite_tensor_; }
|
||||
tensor::MSTensor *lite_tensor() const { return lite_tensor_; }
|
||||
|
||||
Status set_lite_tensor(tensor::MSTensor *tensor) {
|
||||
if (tensor == nullptr) {
|
||||
|
|
|
@ -48,6 +48,7 @@ class RegisterKernelInterface {
|
|||
public:
|
||||
static RegisterKernelInterface *Instance();
|
||||
int Reg(const std::string &vendor, const int op_type, KernelInterfaceCreator creator);
|
||||
virtual ~RegisterKernelInterface() = default;
|
||||
|
||||
private:
|
||||
RegisterKernelInterface() = default;
|
||||
|
|
|
@ -31,6 +31,7 @@ class KernelInterfaceRegistry {
|
|||
}
|
||||
|
||||
int Reg(const std::string &vendor, const int &op_type, kernel::KernelInterfaceCreator creator);
|
||||
virtual ~KernelInterfaceRegistry() = default;
|
||||
|
||||
private:
|
||||
KernelInterfaceRegistry() = default;
|
||||
|
|
|
@ -37,6 +37,8 @@ class SchemaRegisterImpl {
|
|||
|
||||
GetSchemaDef GetPrimTypeGenFunc() const { return prim_type_gen_; }
|
||||
|
||||
virtual ~SchemaRegisterImpl() = default;
|
||||
|
||||
private:
|
||||
std::vector<GetSchemaDef> op_def_funcs_;
|
||||
GetSchemaDef prim_type_gen_;
|
||||
|
@ -45,11 +47,13 @@ class SchemaRegisterImpl {
|
|||
class SchemaOpRegister {
|
||||
public:
|
||||
explicit SchemaOpRegister(GetSchemaDef func) { SchemaRegisterImpl::Instance()->OpPush(func); }
|
||||
virtual ~SchemaOpRegister() = default;
|
||||
};
|
||||
|
||||
class PrimitiveTypeRegister {
|
||||
public:
|
||||
explicit PrimitiveTypeRegister(GetSchemaDef func) { SchemaRegisterImpl::Instance()->SetPrimTypeGenFunc(func); }
|
||||
virtual ~PrimitiveTypeRegister() = default;
|
||||
};
|
||||
} // namespace mindspore::lite::ops
|
||||
|
||||
|
|
|
@ -31,6 +31,7 @@ class NPUInsertTransformPass : public NPUBasePass {
|
|||
name_ = "NPUInsertTransformPass";
|
||||
}
|
||||
|
||||
virtual ~NPUInsertTransformPass() = default;
|
||||
int Run() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -36,6 +36,8 @@ class NPUTransformPass : public NPUBasePass {
|
|||
name_ = "NPUTransformPass";
|
||||
}
|
||||
|
||||
virtual ~NPUTransformPass() = default;
|
||||
|
||||
private:
|
||||
int InsertPreNodes(kernel::LiteKernel *kernel, std::vector<kernel::LiteKernel *> *trans_kernels);
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ class GraphDefTransform {
|
|||
virtual ~GraphDefTransform();
|
||||
virtual int Transform(const converter::Flags &ctx);
|
||||
void SetGraphDef(schema::MetaGraphT *dst_def);
|
||||
inline schema::MetaGraphT *GetOutput() { return graph_defT_; }
|
||||
inline schema::MetaGraphT *GetOutput() const { return graph_defT_; }
|
||||
|
||||
protected:
|
||||
std::vector<schema::CNodeT *> GetGraphNodes();
|
||||
|
|
|
@ -57,6 +57,7 @@ class RegistryPrimitiveAdjust {
|
|||
RegistryPrimitiveAdjust(const std::string &key, PrimitiveAdjustCreator creator) {
|
||||
PrimitiveAdjustRegistry::GetInstance()->InsertPrimitiveAdjustMap(key, creator);
|
||||
}
|
||||
virtual ~RegistryPrimitiveAdjust() = default;
|
||||
};
|
||||
|
||||
#define REGIST_PRIMITIVE_ADJUST(type, primitive_adjust_func) \
|
||||
|
|
Loading…
Reference in New Issue