diff --git a/cmake/package_lite.cmake b/cmake/package_lite.cmake index 055606e95b4..d1eb6f9fdbc 100644 --- a/cmake/package_lite.cmake +++ b/cmake/package_lite.cmake @@ -76,7 +76,8 @@ if (PLATFORM_ARM64) install(FILES ${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite.so DESTINATION ${LIB_DIR} COMPONENT ${COMPONENT_NAME}) install(FILES ${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite.a DESTINATION ${LIB_DIR} COMPONENT ${COMPONENT_NAME}) install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${INC_DIR}/ir/dtype COMPONENT ${COMPONENT_NAME}) - install(DIRECTORY ${TOP_DIR}/mindspore/lite/build/schema/ DESTINATION ${INC_DIR}/schema COMPONENT ${COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "inner" EXCLUDE) + install(FILES ${TOP_DIR}/mindspore/lite/build/schema/model_generated.h DESTINATION ${INC_DIR}/schema COMPONENT ${COMPONENT_NAME}) + install(FILES ${TOP_DIR}/mindspore/lite/build/schema/ops_generated.h DESTINATION ${INC_DIR}/schema COMPONENT ${COMPONENT_NAME}) install(DIRECTORY ${flatbuffers_INC} DESTINATION ${FLATBF_DIR} COMPONENT ${COMPONENT_NAME}) if (ENABLE_TOOLS) install(TARGETS benchmark RUNTIME DESTINATION ${MAIN_DIR}-${COMPONENT_NAME}/benchmark COMPONENT ${COMPONENT_NAME}) @@ -90,7 +91,8 @@ elseif (PLATFORM_ARM32) install(FILES ${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite.so DESTINATION ${LIB_DIR} COMPONENT ${COMPONENT_NAME}) install(FILES ${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite.a DESTINATION ${LIB_DIR} COMPONENT ${COMPONENT_NAME}) install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${INC_DIR}/ir/dtype COMPONENT ${COMPONENT_NAME}) - install(DIRECTORY ${TOP_DIR}/mindspore/lite/build/schema/ DESTINATION ${INC_DIR}/schema COMPONENT ${COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "inner" EXCLUDE) + install(FILES ${TOP_DIR}/mindspore/lite/build/schema/model_generated.h DESTINATION ${INC_DIR}/schema COMPONENT ${COMPONENT_NAME}) + install(FILES ${TOP_DIR}/mindspore/lite/build/schema/ops_generated.h DESTINATION ${INC_DIR}/schema COMPONENT ${COMPONENT_NAME}) install(DIRECTORY ${flatbuffers_INC} DESTINATION ${FLATBF_DIR} COMPONENT ${COMPONENT_NAME}) if (ENABLE_TOOLS) install(TARGETS benchmark RUNTIME DESTINATION ${MAIN_DIR}-${COMPONENT_NAME}/benchmark COMPONENT ${COMPONENT_NAME}) diff --git a/mindspore/lite/CMakeLists.txt b/mindspore/lite/CMakeLists.txt index f4806af9b37..62b533ba212 100644 --- a/mindspore/lite/CMakeLists.txt +++ b/mindspore/lite/CMakeLists.txt @@ -93,10 +93,7 @@ include(${TOP_DIR}/cmake/dependency_utils.cmake) include(${TOP_DIR}/cmake/dependency_securec.cmake) include(${TOP_DIR}/cmake/external_libs/flatbuffers.cmake) -set(FBS_FILES - ${CMAKE_CURRENT_SOURCE_DIR}/schema/model.fbs - ${CMAKE_CURRENT_SOURCE_DIR}/schema/ops.fbs - ) +file(GLOB FBS_FILES ${CMAKE_CURRENT_SOURCE_DIR}/schema/*.fbs) ms_build_flatbuffers_lite(FBS_FILES ${CMAKE_CURRENT_SOURCE_DIR}/schema/ fbs_src ${CMAKE_BINARY_DIR}/schema "") ms_build_flatbuffers_lite(FBS_FILES ${CMAKE_CURRENT_SOURCE_DIR}/schema/ fbs_inner_src ${CMAKE_BINARY_DIR}/schema/inner "inner") diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs index 70ab58645a7..9a951da9c0d 100644 --- a/mindspore/lite/schema/model.fbs +++ b/mindspore/lite/schema/model.fbs @@ -18,6 +18,11 @@ include "ops.fbs"; namespace mindspore.schema; +// This corresponds to the version. +file_identifier "MSL1"; +// File extension of any written files. +file_extension "ms"; + enum NodeType: int { ValueNode, // const Parameter, // var diff --git a/mindspore/lite/schema/model_v0.fbs b/mindspore/lite/schema/model_v0.fbs new file mode 100644 index 00000000000..a0693230136 --- /dev/null +++ b/mindspore/lite/schema/model_v0.fbs @@ -0,0 +1,282 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +include "ops_v0.fbs"; + +namespace mindspore.schema.v0; + +enum NodeType: int { + ValueNode, // const + Parameter, // var + CNode // op +} + +table QuantParam { + scale: double; + zeroPoint: int; + min: double = 0; + max: double = 0; + narrowRange: bool = true; + numBits: int = 8; + inited: bool = false; + varCorr: float = 1; + meanCorr: float = 0; + dstDtype: int = 32; +} + +table Tensor { + nodeType: NodeType; + // data type + dataType: int; + // shape + dims: [int]; + format: Format; + refCount: int; + offset: int; + data: [ubyte]; + quantParams: [QuantParam]; + quantClusters: [float]; +} + +union PrimitiveType { + Concat, + SoftMax, + Activation, + Conv2D, + FusedBatchNorm, + BatchNorm, + BiasAdd, + Pooling, + ROIPooling, + DepthwiseConv2D, + DeDepthwiseConv2D, + Resize, + DetectionPostProcess, + FullConnection, + Mean, + DeConv2D, + Scale, + Reshape, + Eltwise, + NetOutput, + Add, + Sub, + MatMul, + StridedSlice, + Power, + Slice, + Stack, + Mul, + RealDiv, + Pad, + Maximum, + Minimum, + PReLU, + LeakyReLU, + ArgMax, + ArgMin, + Exp, + Crop, + Range, + Rsqrt, + ExpandDims, + Tile, + Cast, + Shape, + Nchw2Nhwc, + Nhwc2Nchw, + QuantDTypeCast, + Split, + Permute, + FakeQuantWithMinMaxVars, + Equal, + Less, + Greater, + NotEqual, + LessEqual, + GreaterEqual, + Min, + Floor, + Abs, + Neg, + Cos, + Sin, + Sqrt, + Square, + Constant, + Log, + Tan, + Atan, + Asin, + Clip, + Transpose, + Squeeze, + Unsqueeze, + Upsample, + Dropout, + Broadcast, + BroadcastTo, + Lrn, + ZerosLike, + TopK, + SpaceToDepth, + SpaceToBatch, + SparseToDense, + ReverseSequence, + Rank, + Gather, + GatherNd, + Fill, + Elu, + DepthToSpace, + BatchToSpace, + AddN, + Ceil, + EmbeddingLookup, + EmbeddingLookupSparse, + FloorDiv, + FloorMod, + L2Norm, + LocalResponseNormalization, + MatrixDiag, + Reduce, + Reverse, + Round, + Select, + Scatter, + ScatterND, + ConstantOfShape, + Unique, + Unstack, + LogicalAnd, + LogicalOr, + LogicalXor, + LogicalNot, + OnnxInt8Quantize, + OnnxInt8Dequantize, + FakeQuantWithMinMax, + FakeQuantWithMinMaxPerChannel, + BatchNormFold, + MulFold, + AddFold, + SquaredDifference, + Flatten, + FlattenGrad, + TupleGetItem, + Div, + Where, + OneHot, + Lstm, + Conv2DGradFilter, + Conv2DGradInput, + PoolingGrad, + BNGrad, + Assign, + ApplyMomentum, + BiasGrad, + SoftmaxCrossEntropy, + AddGrad, + SubGrad, + MulGrad, + DivGrad, + PowerGrad, + ActivationGrad, + PriorBox, + SpaceToBatchND, + Depend, + Return, + MakeTuple, + ToFormat, + Proposal, + Custom, + BlackBox, + NegGrad, + LogGrad, + BatchToSpaceND, + LshProjection, + HashtableLookup, + SkipGram, + DeConv2DGradFilter, + CustomPredict, + CustomNormalize, + CustomExtractFeatures, + AudioSpectrogram, + Mfcc, + Rfft, + FftReal, + FftImag, + Sgd, + Adam, + GroupConv2DGradInput, + Loop, + NonMaxSuppression, + InstanceNorm, + Identity, + LayerNorm, + While, + ControlDepend, + UnsortedSegmentSum, + AssignAdd, + OnesLike, + BinaryCrossEntropyGrad, + BinaryCrossEntropy, + LpNormalization, + DropoutGrad, + MaximumGrad, + MinimumGrad +} + +enum QuantType: int { + QUANT_NONE, + AwareTraining, + WeightQuant, + PostTraining +} + +table Primitive { + value: PrimitiveType; +} + +table CNode { + name: string; + nodeType: NodeType = CNode; + primitive: Primitive; + inputIndex: [uint]; + outputIndex: [uint]; + quantType: QuantType = QUANT_NONE; +} + +table SubGraph { + name:string; + inputIndices: [uint]; + outputIndices: [uint]; + nodeIndices: [uint]; + tensorIndices: [uint]; +} + +table MetaGraph { + name: string; + version: string; + fmkType: int; // 0:tf,1:caffe + inputIndex: [uint]; + outputIndex: [uint]; + mempoolSize: uint; + nodes: [CNode]; + allTensors: [Tensor]; // weight + input + output + subGraph : [SubGraph]; +} + +root_type MetaGraph; diff --git a/mindspore/lite/schema/ops_v0.fbs b/mindspore/lite/schema/ops_v0.fbs new file mode 100644 index 00000000000..a4d080e0898 --- /dev/null +++ b/mindspore/lite/schema/ops_v0.fbs @@ -0,0 +1,1145 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +namespace mindspore.schema.v0; + +enum ResizeMethod: byte { + UNKNOW = -1, + LINEAR = 0, + NEAREST = 1, + CUBIC = 2 +} + +enum CoordinateTransformMode: byte { + COMMON = 0, + HALF_PIXEL = 1, + PYTORCH_HALF_PIXEL = 2, + TF_HALF_PIXEL = 3, + TF_CROP_AND_RESIZE = 4, + ALIGN_CORNERS = 5, + ASYMMETRIC = 6, + ALIGN_CORNERS_WITH_HALF_PIEXL = 7 +} + +enum NearestMode : byte { + NORMAL = 0, + ROUND_HALF_DOWN = 1, + ROUND_HALF_UP = 2, + FLOOR = 3, + CEIL = 4 +} + +enum Format : int { + NCHW = 0, + NHWC, + NHWC4, + HWKC, + HWCK, + KCHW, + CKHW, + KHWC, + CHWK, + HW, + HW4, + NC, + NC4, + NC4HW4 = 100, + NUM_OF_FORMAT +} + +enum ActivationType : byte { + NO_ACTIVATION = 0, + RELU = 1, + SIGMOID = 2, + RELU6 = 3, + ELU = 4, + LEAKY_RELU = 5, + ABS = 6, + RELU1 = 7, + SOFTSIGN = 8, + SOFTPLUS = 9, + TANH = 10, + SELU = 11, + HSWISH = 12, + HSIGMOID = 13, + THRESHOLDRELU = 14, + LINEAR = 15, + HARD_TANH = 16, + SIGN = 17, + SWISH = 18, + UNKNOW = 19 +} +enum ActivationGradType : byte { + NO_ACTIVATION = 0, + RELU = 1, + SIGMOID = 2, + RELU6 = 3, + ELU = 4, + LEAKY_RELU = 5, + ABS = 6, + RELU1 = 7, + SOFTSIGN = 8, + SOFTPLUS = 9, + TANH = 10, + SELU = 11, + HSWISH = 12, + HSIGMOID = 13, + THRESHOLDRELU = 14, + LINEAR = 15, + UNKNOW = 16 +} +enum ReduceType : byte { + REDUCE_MAX = 0, + REDUCE_MEAN = 1, + REDUCE_ALL = 2, + REDUCE_ANY = 3, + REDUCE_LOG_SUM_EXP = 4, + REDUCE_PROD = 5, + REDUCE_SUM = 6, + UNKNOW = 7 +} + +enum PoolMode : byte { + MAX_POOLING = 0, + MEAN_POOLING = 1, +} + +enum EltwiseMode : byte { + PROD = 0, + SUM = 1, + MAXIMUM = 2, + UNKNOW = 3 +} + +enum PadMode : byte { + NOTSET = 0, + SAME_UPPER = 1, + VALID = 2, + CAFFE = 4, + SAME_LOWER = 5 +} + +enum RoundMode : byte { + FLOOR = 0, + CEIL = 1 +} + +enum PaddingMode : byte { + CONSTANT = 0, + REFLECT = 1, + SYMMETRIC = 2, + MODE_RESERVED = 3 +} + +enum LshProjectionType : byte { + UNKNOWN = 0, + SPARSE = 1, + DENSE = 2 +} + +table Pad { + paddings: [int]; + paddingMode: PaddingMode; + constantValue: float; +} + +table Maximum { +} + +table Minimum { +} + +table Flatten { +} +table FlattenGrad { +} +table Concat { + axis: int; + n: int; +} + +table SoftMax { + axis: int = -1; +} + +table Activation { + type: ActivationType = 0; + alpha: float = 0.2; + min_val: float = -1.0; + max_val: float = 1.0; +} +table ActivationGrad { + type: ActivationType = 0; + alpha: float = 0.2; +} + + +table Conv2D { + format: Format = 0; + group: int; + channelIn: int; + channelOut: int; + kernelW: int; + kernelH: int; + strideW: int; + strideH: int; + padMode: PadMode; + padUp: int; + padDown: int; + padLeft: int; + padRight: int; + dilateW: int; + dilateH: int; + hasBias: bool = false; + activationType: ActivationType = 0; +} + +table Conv2DGradFilter { + format: Format = 0; + group: int; + channelIn: int; + channelOut: int; + kernelW: int; + kernelH: int; + strideW: int; + strideH: int; + padMode: PadMode; + padUp: int; + padDown: int; + padLeft: int; + padRight: int; + dilateW: int; + dilateH: int; + hasBias: bool = false; + filter_shape: [int]; + activationType: ActivationType = 0; +} + +table Conv2DGradInput { + format: Format = 0; + group: int; + channelIn: int; + channelOut: int; + kernelW: int; + kernelH: int; + strideW: int; + strideH: int; + padMode: PadMode; + padUp: int; + padDown: int; + padLeft: int; + padRight: int; + dilateW: int; + dilateH: int; + hasBias: bool = false; + input_shape: [int]; + activationType: ActivationType = 0; +} + +table GroupConv2DGradInput { + format: Format = 0; + group: int; + channelIn: int; + channelOut: int; + kernelW: int; + kernelH: int; + strideW: int; + strideH: int; + padMode: PadMode; + padUp: int; + padDown: int; + padLeft: int; + padRight: int; + dilateW: int; + dilateH: int; + hasBias: bool = false; + input_shape: [int]; + activationType: ActivationType = 0; +} + +table FusedBatchNorm { + epsilon: float = 0.00001; // eg. epsilon=0.001 + momentum: float = 0.9; + spatial: int = 1; +} + +table BatchNorm { + epsilon: float = 0.00001; // eg. epsilon=0.001 +} + +table BiasGrad { + axis: [int]; +} + + +table SoftmaxCrossEntropy { + axis: [int]; +} + +table make_tuple { +} + + +table PoolingGrad { + format: Format = 0; + poolingMode: PoolMode; + global: bool = false; + windowW: int; + windowH: int; + strideW: int; + strideH: int; + padMode: PadMode; + padUp: int; + padDown: int; + padLeft: int; + padRight: int; + roundMode: RoundMode; +} +table Shape { +} + +table ConstantOfShape{ + dataType: int; + value: [float]; +} + +table Nchw2Nhwc { + +} + +table Nhwc2Nchw { + +} + +table FakeQuantWithMinMaxVars { + narrowRange: bool; + numBits: int; +} + +table BiasAdd { + axis: [int]; +} + +table ROIPooling { + pooledH: int; + pooledW: int; + scale: float; +} + +table Pooling { + format: Format = 0; + poolingMode: PoolMode; + global: bool = false; + windowW: int; + windowH: int; + strideW: int; + strideH: int; + padMode: PadMode; + padUp: int; + padDown: int; + padLeft: int; + padRight: int; + roundMode: RoundMode; + activationType: ActivationType = 0; + avgMode: int = 0; +} + +table DepthwiseConv2D { + format: Format = 0; + channelIn: int; + channelMultiplier: int; + kernelW: int; + kernelH: int; + strideW: int; + strideH: int; + padMode: PadMode; + padUp: int; + padDown: int; + padLeft: int; + padRight: int; + dilateW: int; + dilateH: int; + hasBias: bool = false; + activationType: ActivationType = 0; +} + +table DeDepthwiseConv2D { + format: Format = 0; + channelIn: int; + channelMultiplier: int; + kernelW: int; + kernelH: int; + strideW: int; + strideH: int; + padMode: PadMode; + padUp: int; + padDown: int; + padLeft: int; + padRight: int; + dilateW: int; + dilateH: int; + hasBias: bool = false; + activationType: ActivationType = 0; +} + + +table Resize { + format: Format = 0; + method: ResizeMethod; + newHeight: long; + newWidth: long; + alignCorners: bool = false; // DEPRECATED IN FUTURE: use 'coordinateTransformMode' instead. + preserveAspectRatio: bool = false; + coordinateTransformMode : CoordinateTransformMode; + cubicCoeff : float; + excludeOutside : int; + extrapolationValue : float = 0; + nearestMode : NearestMode; +} + +table DetectionPostProcess { + format: Format = 0; + inputSize: int; + hScale: float; + wScale: float; + xScale: float; + yScale: float; + NmsIouThreshold: float; + NmsScoreThreshold: float; + MaxDetections: long; + DetectionsPerClass: long; + MaxClassesPerDetection: long; + NumClasses: long; + UseRegularNms: bool; + OutQuantized: bool; +} + +table FullConnection { + hasBias: bool; + axis: int; + useAxis: bool; + activationType: ActivationType = 0; +} + +// Mean(input_tensor, axis, keep_dims) +table Mean { + axis: [int]; + keepDims: bool = false; +} + +table DeConv2D { + format: Format = 0; + group: int; + channelIn: int; + channelOut: int; + kernelW: int; + kernelH: int; + strideW: int; + strideH: int; + padMode: PadMode; + padUp: int; + padDown: int; + padLeft: int; + padRight: int; + dilateW: int; + dilateH: int; + hasBias: bool = false; + activationType: ActivationType = 0; +} + +table DeConv2DGradFilter { + format: Format = 0; + group: int; + channelIn: int; + channelOut: int; + kernelW: int; + kernelH: int; + strideW: int; + strideH: int; + padMode: PadMode; + padUp: int; + padDown: int; + padLeft: int; + padRight: int; + dilateW: int; + dilateH: int; + hasBias: bool = false; + activationType: ActivationType = 0; +} + +table BNGrad { + eps: float; + momentum: float; +} + +table Scale { + axis: int; + activationType: ActivationType = 0; +} + +table Eltwise { + mode: EltwiseMode; +} + +table Add { + activationType: ActivationType = 0; +} + +table Sub { + activationType: ActivationType = 0; +} + +table Mul { + activationType: ActivationType = 0; +} + +table Div { + activationType: ActivationType = 0; +} + +table AddGrad { +} + +table SubGrad { +} + +table MulGrad { +} + +table DivGrad { +} +table RealDiv { +} + +table Rsqrt { +} + +table Equal { +} + +table Less { +} + +table Greater { +} + +table NotEqual { +} + +table LessEqual { +} + +table GreaterEqual { +} + +table Min { +} + +table Slice { + format: Format = 0; + axes: [int]; + begin: [int]; + size: [int]; +} + +table Floor { +} + +table Abs { +} + +table Neg { +} + +table NegGrad { +} + +table Exp { + base : float = -1.0; + scale : float = 1.0; + shift : float = 0.0; +} + +table Cos { +} + +table Sin { +} + +table Sqrt { +} + +table Square { +} + +table Ceil { +} + +table Log { +} + +table LogGrad { +} + +table Tan { +} + +table Atan { +} + +table Asin { +} + +table Reshape { + format: Format = 0; + shape: [long]; +} + +table Power { + power: float; + scale: float; + shift: float; +} +table PowerGrad { + power: float; + scale: float; + shift: float; +} +table ArgMax { + axis: int; + outMaxValue: bool; + topK: int = 1; + keepDims: bool; + axisType: int; +} + +table ArgMin { + axis: int; + outMaxValue: bool; + topK: int = 1; + keepDims: bool; + axisType: int; +} + +table NetOutput { +} + +table MatMul { + broadcast : bool = false; + transposeA : bool = false; + transposeB : bool = false; +} + +table PReLU { + channelShared : bool = false; + slope: [float]; +} + +table LeakyReLU { + negativeSlope: float; +} + +table StridedSlice { + beginMask: int; + endMask: int; + ellipsisMask: int; + newAxisMask: int; + shrinkAxisMask: int; + begin: [int]; + end: [int]; + stride: [int]; + isScale: [int]; +} + +table Stack { + axis: int; + n: int; + isScale: [int]; +} + +table Range { + dType: int; + start: int; + limit: int; + delta: int; +} + +table ExpandDims { + dim: int; +} + +table Tile { + multiples: [int]; + dims: [int]; +} + +table Cast { + srcT: int; + dstT: int; +} + +table QuantDTypeCast { + srcT: int; + dstT: int; +} + +table Split { + numberSplit: int; + sizeSplits: [int]; + splitDim: int; +} + +table Crop { + axis : long; + offsets : [long]; +} + +table Permute { + order: [long]; +} + +table Clip { + max: float; + min: float; +} + +table Constant { +} + + +table Elu { + alpha: float = 1.0; +} + +table Broadcast { +} + +table BroadcastTo { + dst_shape: [int]; +} + +table Lrn { + alpha: float = 0.0001; + beta: float = 0.75; + bias: float = 1.0; + size: int; +} + +enum ReduceMode : byte { + ReduceMean = 0, + ReduceMax = 1, + ReduceMin = 2, + ReduceProd = 3, + ReduceSum = 4, + ReduceSumSquare = 5, + ReduceASum = 6 +} + +table Reduce { + axes: [int]; + keepDims: int; + mode: ReduceMode; + reduceToEnd: bool = false; + coeff: float = 1.0; +} + +table Transpose { + perm: [int]; + conjugate: bool = false; +} + +table Squeeze { + axis: [int]; +} + +table Unsqueeze { + axis: [int]; +} + +table Upsample { + mode: string; + scales: [float]; +} + +table Dropout { + ratio : float = 0.5; +} + +table LocalResponseNormalization { + depth_radius: int; + bias: float; + alpha: float; + beta: float; +} + +table ZerosLike { +} + +table TopK { + k : int; + sorted : bool = true; +} + +table SpaceToDepth { + blockSize : int; + format: Format = 0; +} + +table SpaceToBatch { + blockShape : [int]; + paddings : [int]; +} + +table SparseToDense { + validateIndices: bool; +} + +table ReverseSequence { + seqAxis: int; + batchAxis: int; +} + +table Rank { +} + + +table Gather { + axis: int; + batchDims: int; +} + +table GatherNd { + batchDims: int; +} + +table Fill { + dims: [int]; +} + +table DepthToSpace { + blockSize: int; + format: Format = 0; +} + + +table BatchToSpace { + blockShape: [int]; + crops: [int]; +} + +table BatchToSpaceND { + blockShape: [int]; + crops: [int]; +} + +table AddN { + N: int; +} + + +table EmbeddingLookup { + maxNorm: float = 0.0; +} + +table EmbeddingLookupSparse { + spIds: [int]; + spWeights: [float]; + //combiner: Combiner=0; + maxNortm: float; +} + +table FloorDiv { +} + +table FloorMod { +} + +table L2Norm { + axis: [int]; + epsilon: float; + activationType: ActivationType = 0; +} + +table LogicalAnd { +} + +table LogicalOr { +} + +table LogicalXor { +} + +table LogicalNot { +} + +table MatrixDiag { + k: int; + numRows: int; + numCols: int; + paddingValue: float; +} + +table Select { +} + +table TfReduce { + type: ReduceType = 7; +} + +table Reverse { + axis: [int]; +} + +table Round { +} + +table Scatter { +} + +table ScatterND { +} + +table Unique { + outType: int; +} + +table Unstack { + num: int; + axis: int; +} + +table OnnxInt8Quantize { +} + +table OnnxInt8Dequantize { +} + +table FakeQuantWithMinMax { +} + +table FakeQuantWithMinMaxPerChannel { +} + +table BatchNormFold { +} + +table MulFold { +} + +table AddFold { +} + +table SquaredDifference { +} + +table TupleGetItem { +} + +table ApplyMomentum { + gradientScale: float; + useNesterov: bool; +} + +table Sgd { + weightDecay: float; + dampening: float; + useNesterov: bool; +} + +table Adam { + useNesterov: bool; +} + +table Assign { +} + +table AssignAdd { +} + +table Where{ + condition: [bool]; +} + +table OneHot { + axis: int; +} + +table Lstm{ + bidirection: bool = false; +} + +table PriorBox { + min_sizes: [int]; + max_sizes: [int]; + aspect_ratios: [float]; + variances: [float]; + image_size_w: int; + image_size_h: int; + step_w: float; + step_h: float; + clip: bool = true; + flip: bool = true; + offset: float; +} + +table SpaceToBatchND { + blockShape : [int]; + paddings : [int]; +} + +table MakeTuple { +} + +table ToFormat { + srcT: int; + dstT: int; +} + + +table Depend { +} + +table ControlDepend { +} + +table Return { +} + +table Proposal { + feat_stride : float; + base_size : float; + min_size : float; + ratio : [float]; + scale : [float]; + pre_nms_topn : int; + post_nms_topn : int; + nms_thresh : float; +} + +table Custom { + custom : [ubyte]; +} + + +table BlackBox { + id : string; + size : int; + address : [ubyte]; +} + +table LshProjection { + type : LshProjectionType; +} + +table HashtableLookup { +} + +table SkipGram { + includeAllGrams : bool; + maxSkipSize : int; + ngramSize : int; +} + +table CustomPredict { + outputNum : int; + weightThreshold : float; +} + +table CustomNormalize { +} + +table CustomExtractFeatures { +} + +table AudioSpectrogram { + windowSize : int; + stride : int; + magSquare : bool; +} + +table Mfcc { + freqUpperLimit : float; + freqLowerLimit : float; + filterBankChannelNum : int; + dctCoeffNum : int; +} + +table Rfft { + fftLength : int; +} + +table FftReal { +} + +table FftImag { +} + +table DropoutGrad { + ratio : float = 0.5; +} + +table MaximumGrad { +} + +table MinimumGrad { +} + +table NonMaxSuppression { + centerPointBox : int = 0; +} + +table InstanceNorm { + epsilon : float = 0.00001; +} + +table Loop { + subGraphIndex : int; +} + +table Identity { +} + +table LayerNorm { + normalizedShape : [int]; + epsilon : float = 0.00001; + elementwiseAffine : bool; +} + +table While { + condSubgraphIndex : int; + bodySubgraphIndex : int; +} + +table UnsortedSegmentSum { + numSegments : int; +} + +table OnesLike { + +} + +table BinaryCrossEntropy { + reduction : int = 1; +} + +table BinaryCrossEntropyGrad { + reduction : int = 1; +} + +table LpNormalization { + axis : int; + p : int; +} diff --git a/mindspore/lite/src/common/common.h b/mindspore/lite/src/common/common.h index 9664dea33d4..5cfb5671c6a 100644 --- a/mindspore/lite/src/common/common.h +++ b/mindspore/lite/src/common/common.h @@ -32,7 +32,7 @@ enum CHWK_SHAPE { CHWK_C = 0, CHWK_H = 1, CHWK_W = 2, CHWK_K = 3 }; enum KHWC_SHAPE { KHWC_K = 0, KHWC_H = 1, KHWC_W = 2, KHWC_C = 3 }; enum CHW_SHAPE { CHW_C = 0, CHW_H = 1, CHW_W = 2 }; enum HWC_SHAPE { HWC_H = 0, HWC_W = 1, HWC_C = 2 }; -enum SCHEMA_VERSION { SCHEMA_CUR = 0 }; +enum SCHEMA_VERSION { SCHEMA_INVALID = -1, SCHEMA_CUR = 0, SCHEMA_V0 = 1 }; static constexpr int kNCHWDimNumber = 4; static constexpr int kNHWCDimNumber = 4; diff --git a/mindspore/lite/src/model_common.cc b/mindspore/lite/src/model_common.cc index 553bbd1ec0d..87428952cc7 100644 --- a/mindspore/lite/src/model_common.cc +++ b/mindspore/lite/src/model_common.cc @@ -51,14 +51,18 @@ int ConvertSubGraph(const schema::SubGraph &sub_graph, Model *model) { int VersionVerify(flatbuffers::Verifier *verify) { if (schema::VerifyMetaGraphBuffer(*verify)) { return SCHEMA_VERSION::SCHEMA_CUR; + } else if (schema::v0::VerifyMetaGraphBuffer(*verify)) { + return SCHEMA_VERSION::SCHEMA_V0; } - return -1; + return SCHEMA_VERSION::SCHEMA_INVALID; } const void *GetMetaGraphByVerison(const char *buf, const int &schema_version) { MS_ASSERT(buf != nullptr); if (schema_version == SCHEMA_VERSION::SCHEMA_CUR) { return reinterpret_cast(schema::GetMetaGraph(buf)); + } else if (schema_version == SCHEMA_VERSION::SCHEMA_V0) { + return reinterpret_cast(schema::v0::GetMetaGraph(buf)); } return nullptr; } @@ -69,6 +73,9 @@ int GenerateModelByVersion(const void *meta_graph, Model *model, const int &sche if (schema_version == SCHEMA_VERSION::SCHEMA_CUR) { status = GenerateModel(*reinterpret_cast(meta_graph), model, schema_version); + } else if (schema_version == SCHEMA_VERSION::SCHEMA_V0) { + status = GenerateModel( + *reinterpret_cast(meta_graph), model, schema_version); } return status; } diff --git a/mindspore/lite/src/model_common.h b/mindspore/lite/src/model_common.h index 54ad92f6fb2..866254cd991 100644 --- a/mindspore/lite/src/model_common.h +++ b/mindspore/lite/src/model_common.h @@ -22,6 +22,7 @@ #include "include/model.h" #include "include/version.h" #include "schema/model_generated.h" +#include "schema/model_v0_generated.h" #include "src/common/common.h" #ifndef PRIMITIVE_WRITEABLE #include "src/ops/ops_register.h" diff --git a/mindspore/lite/src/train/train_model.cc b/mindspore/lite/src/train/train_model.cc index 2aa81239b12..0cb759be238 100644 --- a/mindspore/lite/src/train/train_model.cc +++ b/mindspore/lite/src/train/train_model.cc @@ -28,7 +28,8 @@ TrainModel *TrainModel::Import(const char *model_buf, size_t size) { return nullptr; } flatbuffers::Verifier verify((const uint8_t *)model_buf, size); - if (!schema::VerifyMetaGraphBuffer(verify)) { + int schema_version = VersionVerify(&verify); + if (schema_version == -1) { MS_LOG(ERROR) << "The buffer is invalid and fail to create graph."; return nullptr; } @@ -45,49 +46,19 @@ TrainModel *TrainModel::Import(const char *model_buf, size_t size) { } memcpy(model->buf, model_buf, size); model->buf_size_ = size; - auto meta_graph = schema::GetMetaGraph(model->buf); + const void *meta_graph = GetMetaGraphByVerison(model->buf, schema_version); if (meta_graph == nullptr) { - delete model; MS_LOG(ERROR) << "meta_graph is nullptr!"; + delete (model); return nullptr; } - if (meta_graph->name() != nullptr) { - model->name_ = meta_graph->name()->c_str(); - } - if (meta_graph->version() != nullptr) { - model->version_ = meta_graph->version()->c_str(); - } - if (!ConvertNodes(*meta_graph, model)) { - delete model; + int status = GenerateModelByVersion(meta_graph, model, schema_version); + if (status != RET_OK) { + delete (model); + MS_LOG(ERROR) << "fail to generate model"; return nullptr; } - - if (!ConvertTensors(*meta_graph, model)) { - delete model; - return nullptr; - } - - if (meta_graph->subGraph() == nullptr) { - int ret = MetaGraphMappingSubGraph(*meta_graph, model); - if (ret != RET_OK) { - MS_LOG(ERROR) << "converter old version model wrong."; - delete model; - return nullptr; - } - } else { - auto sub_graphs = meta_graph->subGraph(); - auto sub_graph_size = sub_graphs->size(); - for (size_t i = 0; i < sub_graph_size; i++) { - auto sub_graph = sub_graphs->GetAs(i); - int ret = ConvertSubGraph(*sub_graph, model); - if (ret != RET_OK) { - MS_LOG(ERROR) << "converter subgraph wrong."; - delete model; - return nullptr; - } - } - } return model; } diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/network_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/network_test.cc index 6f1aa1f51ea..cc648fe5b86 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/network_test.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/network_test.cc @@ -355,6 +355,7 @@ TEST_F(NetworkTest, tuning_layer) { flatbuffers::FlatBufferBuilder builder(1024); auto offset = schema::MetaGraph::Pack(builder, meta_graph.get()); builder.Finish(offset); + schema::FinishMetaGraphBuffer(builder, offset); size_t size = builder.GetSize(); const char *content = reinterpret_cast(builder.GetBufferPointer()); std::cout << "build fb size= " << size << std::endl; diff --git a/mindspore/lite/test/ut/src/scheduler_test.cc b/mindspore/lite/test/ut/src/scheduler_test.cc index 1adc4d85a82..e183de5f369 100644 --- a/mindspore/lite/test/ut/src/scheduler_test.cc +++ b/mindspore/lite/test/ut/src/scheduler_test.cc @@ -165,6 +165,7 @@ TEST_F(SchedulerTest, TestConstructSubGraphsTwoBranch) { flatbuffers::FlatBufferBuilder builder(1024); auto offset = mindspore::schema::MetaGraph::Pack(builder, meta_graph.get()); builder.Finish(offset); + mindspore::schema::FinishMetaGraphBuffer(builder, offset); size_t size = builder.GetSize(); const char *content = reinterpret_cast(builder.GetBufferPointer()); auto model = mindspore::lite::Model::Import(content, size); @@ -349,6 +350,7 @@ TEST_F(SchedulerTest, TestConstructSubGraphsThreeBranch) { flatbuffers::FlatBufferBuilder builder(1024); auto offset = mindspore::schema::MetaGraph::Pack(builder, meta_graph.get()); builder.Finish(offset); + mindspore::schema::FinishMetaGraphBuffer(builder, offset); size_t size = builder.GetSize(); const char *content = reinterpret_cast(builder.GetBufferPointer()); auto model = mindspore::lite::Model::Import(content, size); diff --git a/mindspore/lite/tools/common/storage.cc b/mindspore/lite/tools/common/storage.cc index abf20506f82..a94107b6d8a 100644 --- a/mindspore/lite/tools/common/storage.cc +++ b/mindspore/lite/tools/common/storage.cc @@ -27,6 +27,7 @@ int Storage::Save(const schema::MetaGraphT &graph, const std::string &outputPath flatbuffers::FlatBufferBuilder builder(1024); auto offset = schema::MetaGraph::Pack(builder, &graph); builder.Finish(offset); + schema::FinishMetaGraphBuffer(builder, offset); int size = builder.GetSize(); auto content = builder.GetBufferPointer(); if (content == nullptr) {