forked from mindspore-Ecosystem/mindspore
!8761 [MSLITE] add tensorflow model parser
From: @zhengjun10
Reviewed-by: @hangangqiang, @HilbertDavid
Signed-off-by: @hangangqiang
Commit: 63f0b2c64a
@@ -306,6 +306,7 @@ if (ENABLE_CONVERTER)
        tflite_parser_mid
        caffe_parser_mid
        onnx_parser_mid
        tf_parser_mid
        graph_pass_mid
        fusion_mid
        quantizer_mid
@@ -61,6 +61,7 @@ add_subdirectory(../anf_exporter anf_exporter)
add_subdirectory(parser/caffe)
add_subdirectory(parser/tflite)
add_subdirectory(parser/onnx)
add_subdirectory(parser/tf)
add_subdirectory(legacy_optimizer)
add_subdirectory(quantizer)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../../core mindspore_core)
@@ -111,6 +112,7 @@ endif ()

file(GLOB PROTO_FILE ""
        ${CMAKE_CURRENT_SOURCE_DIR}/parser/caffe/caffe.proto
        ${CMAKE_CURRENT_SOURCE_DIR}/parser/tf/*.proto
        ${CMAKE_CURRENT_SOURCE_DIR}/parser/onnx/onnx.proto)
ms_protobuf_generate(PROTO_SRCS PROTO_HDRS ${PROTO_FILE})
add_library(proto_mid OBJECT ${PROTO_SRCS})
@@ -138,6 +140,7 @@ add_dependencies(converter_lite fbs_inner_src)

target_link_libraries(converter_lite PRIVATE
        tflite_parser_mid
        tf_parser_mid
        caffe_parser_mid
        onnx_parser_mid
        anf_importer_mid
@@ -0,0 +1,7 @@ parser/tf/CMakeLists.txt (new file)
file(GLOB_RECURSE TF_SRC_LIST ${CMAKE_CURRENT_SOURCE_DIR}/*.cc)

set_property(SOURCE ${TF_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)

add_library(tf_parser_mid OBJECT ${TF_SRC_LIST})

add_dependencies(tf_parser_mid proto_mid)
@@ -0,0 +1,62 @@ attr_value.proto (new file)
syntax = "proto3";

package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "AttrValueProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "tensor.proto";
import "tensor_shape.proto";
import "types.proto";

// Protocol buffer representing the value for an attr used to configure an Op.
// Comment indicates the corresponding attr type. Only the field matching the
// attr type may be filled.
message AttrValue {
  // LINT.IfChange
  message ListValue {
    repeated bytes s = 2;                        // "list(string)"
    repeated int64 i = 3 [packed = true];        // "list(int)"
    repeated float f = 4 [packed = true];        // "list(float)"
    repeated bool b = 5 [packed = true];         // "list(bool)"
    repeated DataType type = 6 [packed = true];  // "list(type)"
    repeated TensorShapeProto shape = 7;         // "list(shape)"
    repeated TensorProto tensor = 8;             // "list(tensor)"
    repeated NameAttrList func = 9;              // "list(attr)"
  }
  // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc)

  oneof value {
    bytes s = 2;                 // "string"
    int64 i = 3;                 // "int"
    float f = 4;                 // "float"
    bool b = 5;                  // "bool"
    DataType type = 6;           // "type"
    TensorShapeProto shape = 7;  // "shape"
    TensorProto tensor = 8;      // "tensor"
    ListValue list = 1;          // any "list(...)"

    // "func" represents a function. func.name is a function's name or
    // a primitive op's name. func.attr.first is the name of an attr
    // defined for that function. func.attr.second is the value for
    // that attr in the instantiation.
    NameAttrList func = 10;

    // This is a placeholder only used in nodes defined inside a
    // function. It indicates the attr value will be supplied when
    // the function is instantiated. For example, let us suppose a
    // node "N" in function "FN". "N" has an attr "A" with value
    // placeholder = "foo". When FN is instantiated with attr "foo"
    // set to "bar", the instantiated node N's attr A will have been
    // given the value "bar".
    string placeholder = 9;
  }
}

// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NameAttrList {
  string name = 1;
  map<string, AttrValue> attr = 2;
}
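The converter reads these messages through the generated C++ protobuf API. Below is a minimal sketch of how one attribute could be inspected; the helper name, the node variable, and the generated header path are illustrative assumptions, not part of this commit:

```cpp
#include "proto/attr_value.pb.h"  // generated from attr_value.proto; path assumed to follow this commit's layout
#include "proto/node_def.pb.h"

// Illustrative helper: fetch the "dtype" attribute of a node if it carries a type value.
bool GetDtypeAttr(const tensorflow::NodeDef &node, tensorflow::DataType *dtype) {
  const auto &attrs = node.attr();  // map<string, AttrValue>
  auto it = attrs.find("dtype");
  if (it == attrs.end()) {
    return false;
  }
  const tensorflow::AttrValue &value = it->second;
  // Only the field matching the attr type is filled; check the oneof case first.
  if (value.value_case() != tensorflow::AttrValue::kType) {
    return false;
  }
  *dtype = value.type();
  return true;
}
```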
@@ -0,0 +1,101 @@ function.proto (new file)
syntax = "proto3";

package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "FunctionProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "attr_value.proto";
import "node_def.proto";
import "op_def.proto";

// A library is a set of named functions.
message FunctionDefLibrary {
  repeated FunctionDef function = 1;
  repeated GradientDef gradient = 2;
}

// A function can be instantiated when the runtime can bind every attr
// with a value. When a GraphDef has a call to a function, it must
// have binding for every attr defined in the signature.
//
// TODO(zhifengc):
//   * device spec, etc.
message FunctionDef {
  // The definition of the function's name, arguments, return values,
  // attrs etc.
  OpDef signature = 1;

  // Attributes specific to this function definition.
  map<string, AttrValue> attr = 5;

  // NOTE: field id 2 deleted on Jan 11, 2016, GraphDef version 21.

  // In both of the following fields, there is the need to specify an
  // output that is used as either the input to another node (in
  // `node_def`) or as a return value of the function (in `ret`).
  // Unlike the NodeDefs in GraphDef, we need to be able to specify a
  // list in some cases (instead of just single outputs). Also, we
  // need to be able to deal with lists of unknown length (so the
  // output index may not be known at function definition time). So
  // we use the following format instead:
  // * "fun_in" where "fun_in" is the name of a function input arg in
  //   the `signature` field above. This represents that input, whether
  //   it is a single tensor or a list.
  // * "fun_in:0" gives the first element of a function input arg (a
  //   non-list input is considered a list of length 1 for these
  //   purposes).
  // * "node:out" where "node" is the name of a node in `node_def` and
  //   "out" is the name one of its op's output arguments (the name
  //   comes from the OpDef of the node's op). This represents that
  //   node's output, whether it is a single tensor or a list.
  //   Note: We enforce that an op's output arguments are never
  //   renamed in the backwards-compatibility test.
  // * "node:out:0" gives the first element of a node output arg (a
  //   non-list output is considered a list of length 1 for these
  //   purposes).
  //
  // NOT CURRENTLY SUPPORTED (but may be in the future):
  // * "node:out:-1" gives last element in a node output list
  // * "node:out:1:" gives a list with all but the first element in a
  //   node output list
  // * "node:out::-1" gives a list with all but the last element in a
  //   node output list

  // The body of the function. Unlike the NodeDefs in a GraphDef, attrs
  // may have values of type `placeholder` and the `input` field uses
  // the "output" format above.

  // By convention, "op" in node_def is resolved by consulting with a
  // user-defined library first. If not resolved, "func" is assumed to
  // be a builtin op.
  repeated NodeDef node_def = 3;

  // A mapping from the output arg names from `signature` to the
  // outputs from `node_def` that should be returned by the function.
  map<string, string> ret = 4;
}

// GradientDef defines the gradient function of a function defined in
// a function library.
//
// A gradient function g (specified by gradient_func) for a function f
// (specified by function_name) must follow the following:
//
// The function 'f' must be a numerical function which takes N inputs
// and produces M outputs. Its gradient function 'g', which is a
// function taking N + M inputs and produces N outputs.
//
// I.e. if we have
//   (y1, y2, ..., y_M) = f(x1, x2, ..., x_N),
// then, g is
//   (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N,
//                                      dL/dy1, dL/dy2, ..., dL/dy_M),
// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the
// loss function). dL/dx_i is the partial derivative of L with respect
// to x_i.
message GradientDef {
  string function_name = 1;  // The function name.
  string gradient_func = 2;  // The gradient function's name.
}
@@ -0,0 +1,56 @@ graph.proto (new file)
syntax = "proto3";

package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "GraphProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "node_def.proto";
import "function.proto";
import "versions.proto";

// Represents the graph of operations
message GraphDef {
  repeated NodeDef node = 1;

  // Compatibility versions of the graph. See core/public/version.h for version
  // history. The GraphDef version is distinct from the TensorFlow version, and
  // each release of TensorFlow will support a range of GraphDef versions.
  VersionDef versions = 4;

  // Deprecated single version field; use versions above instead. Since all
  // GraphDef changes before "versions" was introduced were forward
  // compatible, this field is entirely ignored.
  int32 version = 3 [deprecated = true];

  // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET.
  //
  // "library" provides user-defined functions.
  //
  // Naming:
  //   * library.function.name are in a flat namespace.
  //     NOTE: We may need to change it to be hierarchical to support
  //     different orgs. E.g.,
  //     { "/google/nn", { ... }},
  //     { "/google/vision", { ... }}
  //     { "/org_foo/module_bar", { ... }}
  //     map<string, FunctionDefLib> named_lib;
  //   * If node[i].op is the name of one function in "library",
  //     node[i] is deemed as a function call. Otherwise, node[i].op
  //     must be a primitive operation supported by the runtime.
  //
  //
  // Function call semantics:
  //
  //   * The callee may start execution as soon as some of its inputs
  //     are ready. The caller may want to use Tuple() mechanism to
  //     ensure all inputs are ready in the same time.
  //
  //   * The consumer of return values may start executing as soon as
  //     the return values the consumer depends on are ready. The
  //     consumer may want to use Tuple() mechanism to ensure the
  //     consumer does not start until all return values of the callee
  //     function are ready.
  FunctionDefLibrary library = 2;
};
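As a quick illustration of how a GraphDef is walked (the same loop shape that ConvertOps uses further below), here is a hedged, stand-alone sketch that tallies op types; the function name is made up for this example:

```cpp
#include <map>
#include <string>
#include "proto/graph.pb.h"

// Illustrative only: count how many nodes of each op type a GraphDef contains,
// which is handy when checking which ops still need a TFNodeParser.
std::map<std::string, int> CountOpTypes(const tensorflow::GraphDef &graph) {
  std::map<std::string, int> counts;
  for (int i = 0; i < graph.node_size(); ++i) {
    counts[graph.node(i).op()] += 1;
  }
  return counts;
}
```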
@@ -0,0 +1,63 @@ node_def.proto (new file)
syntax = "proto3";

package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "NodeProto";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "attr_value.proto";

message NodeDef {
  // The name given to this operator. Used for naming inputs,
  // logging, visualization, etc. Unique within a single GraphDef.
  // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*".
  string name = 1;

  // The operation name. There may be custom parameters in attrs.
  // Op names starting with an underscore are reserved for internal use.
  string op = 2;

  // Each input is "node:src_output" with "node" being a string name and
  // "src_output" indicating which output tensor to use from "node". If
  // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs
  // may optionally be followed by control inputs that have the format
  // "^node".
  repeated string input = 3;

  // A (possibly partial) specification for the device on which this
  // node should be placed.
  // The expected syntax for this string is as follows:
  //
  // DEVICE_SPEC ::= PARTIAL_SPEC
  //
  // PARTIAL_SPEC ::= ("/" CONSTRAINT) *
  // CONSTRAINT ::= ("job:" JOB_NAME)
  //              | ("replica:" [1-9][0-9]*)
  //              | ("task:" [1-9][0-9]*)
  //              | ( ("gpu" | "cpu") ":" ([1-9][0-9]* | "*") )
  //
  // Valid values for this string include:
  // * "/job:worker/replica:0/task:1/gpu:3"  (full specification)
  // * "/job:worker/gpu:3"                   (partial specification)
  // * ""                                    (no specification)
  //
  // If the constraints do not resolve to a single device (or if this
  // field is empty or not present), the runtime will attempt to
  // choose a device automatically.
  string device = 4;

  // Operation-specific graph-construction-time configuration.
  // Note that this should include all attrs defined in the
  // corresponding OpDef, including those with a value matching
  // the default -- this allows the default to change and makes
  // NodeDefs easier to interpret on their own. However, if
  // an attr with a default is not specified in this list, the
  // default will be used.
  // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and
  // one of the names from the corresponding OpDef's attr field).
  // The values must have a type matching the corresponding OpDef
  // attr's type field.
  // TODO(josh11b): Add some examples here showing best practices.
  map<string, AttrValue> attr = 5;
};
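The `input` strings follow the "node:src_output" / "^node" convention documented above. A small self-contained sketch of how such a string could be split is shown below; it is purely illustrative (this commit's parser simply looks the raw string up in tf_node_map), and the struct and function names are invented for the example:

```cpp
#include <string>

// Illustrative decoder for a NodeDef input string:
//   "weights"   -> node "weights", output 0, not a control input
//   "conv2d:1"  -> node "conv2d",  output 1
//   "^init_op"  -> control dependency on "init_op"
struct TFInputRef {
  std::string node;
  int src_output = 0;
  bool is_control = false;
};

TFInputRef ParseTFInput(const std::string &input) {
  TFInputRef ref;
  std::string s = input;
  if (!s.empty() && s[0] == '^') {
    ref.is_control = true;
    s = s.substr(1);
  }
  auto pos = s.rfind(':');
  if (pos != std::string::npos) {
    ref.node = s.substr(0, pos);
    ref.src_output = std::stoi(s.substr(pos + 1));
  } else {
    ref.node = s;
  }
  return ref;
}
```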
@@ -0,0 +1,157 @@ op_def.proto (new file)
syntax = "proto3";

package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "OpDefProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "attr_value.proto";
import "types.proto";

// Defines an operation. A NodeDef in a GraphDef specifies an Op by
// using the "op" field which should match the name of a OpDef.
message OpDef {
  // Op names starting with an underscore are reserved for internal use.
  // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*".
  string name = 1;

  // For describing inputs and outputs.
  message ArgDef {
    // Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*".
    string name = 1;

    // Human readable description.
    string description = 2;

    // Describes the type of one or more tensors that are accepted/produced
    // by this input/output arg. The only legal combinations are:
    // * For a single tensor: either the "type" field is set or the
    //   "type_attr" field is set to the name of an attr with type "type".
    // * For a sequence of tensors with the same type: the "number_attr"
    //   field will be set to the name of an attr with type "int", and
    //   either the "type" or "type_attr" field will be set as for
    //   single tensors.
    // * For a sequence of tensors, the "type_list_attr" field will be set
    //   to the name of an attr with type "list(type)".
    DataType type = 3;
    string type_attr = 4;    // if specified, attr must have type "type"
    string number_attr = 5;  // if specified, attr must have type "int"
    // If specified, attr must have type "list(type)", and none of
    // type, type_attr, and number_attr may be specified.
    string type_list_attr = 6;

    // For inputs: if true, the inputs are required to be refs.
    // By default, inputs can be either refs or non-refs.
    // For outputs: if true, outputs are refs, otherwise they are not.
    bool is_ref = 16;
  };

  // Description of the input(s).
  repeated ArgDef input_arg = 2;

  // Description of the output(s).
  repeated ArgDef output_arg = 3;

  // Description of the graph-construction-time configuration of this
  // Op. That is to say, this describes the attr fields that will
  // be specified in the NodeDef.
  message AttrDef {
    // A descriptive name for the argument. May be used, e.g. by the
    // Python client, as a keyword argument name, and so should match
    // the regexp "[a-z][a-z0-9_]+".
    string name = 1;

    // One of the type names from attr_value.proto ("string", "list(string)",
    // "int", etc.).
    string type = 2;

    // A reasonable default for this attribute if the user does not supply
    // a value. If not specified, the user must supply a value.
    AttrValue default_value = 3;

    // Human-readable description.
    string description = 4;

    // TODO(josh11b): bool is_optional?

    // --- Constraints ---
    // These constraints are only in effect if specified. Default is no
    // constraints.

    // For type == "int", this is a minimum value. For "list(___)"
    // types, this is the minimum length.
    bool has_minimum = 5;
    int64 minimum = 6;

    // The set of allowed values. Has type that is the "list" version
    // of the "type" field above (uses the "list" field of AttrValue).
    // If type == "type" or "list(type)" above, then the "type" field
    // of "allowed_values.list" has the set of allowed DataTypes.
    // If type == "string" or "list(string)", then the "s" field of
    // "allowed_values.list" has the set of allowed strings.
    AttrValue allowed_values = 7;
  }
  repeated AttrDef attr = 4;

  // Optional deprecation based on GraphDef versions.
  OpDeprecation deprecation = 8;

  // One-line human-readable description of what the Op does.
  string summary = 5;

  // Additional, longer human-readable description of what the Op does.
  string description = 6;

  // -------------------------------------------------------------------------
  // Which optimizations this operation can participate in.

  // True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs)
  bool is_commutative = 18;

  // If is_aggregate is true, then this operation accepts N >= 2
  // inputs and produces 1 output all of the same type. Should be
  // associative and commutative, and produce output with the same
  // shape as the input. The optimizer may replace an aggregate op
  // taking input from multiple devices with a tree of aggregate ops
  // that aggregate locally within each device (and possibly within
  // groups of nearby devices) before communicating.
  // TODO(josh11b): Implement that optimization.
  bool is_aggregate = 16;  // for things like add

  // Other optimizations go here, like
  //   can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc.

  // -------------------------------------------------------------------------
  // Optimization constraints.

  // By default Ops may be moved between devices. Stateful ops should
  // either not be moved, or should only be moved if that state can also
  // be moved (e.g. via some sort of save / restore).
  // Stateful ops are guaranteed to never be optimized away by Common
  // Subexpression Elimination (CSE).
  bool is_stateful = 17;  // for things like variables, queue

  // -------------------------------------------------------------------------
  // Non-standard options.

  // By default, all inputs to an Op must be initialized Tensors. Ops
  // that may initialize tensors for the first time should set this
  // field to true, to allow the Op to take an uninitialized Tensor as
  // input.
  bool allows_uninitialized_input = 19;  // for Assign, etc.
};

// Information about version-dependent deprecation of an op
message OpDeprecation {
  // First GraphDef version at which the op is disallowed.
  int32 version = 1;

  // Explanation of why it was deprecated and what to use instead.
  string explanation = 2;
};

// A collection of OpDefs
message OpList {
  repeated OpDef op = 1;
};
@@ -0,0 +1,29 @@ resource_handle.proto (new file)
syntax = "proto3";

package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "ResourceHandle";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

// Protocol buffer representing a handle to a tensorflow resource. Handles are
// not valid across executions, but can be serialized back and forth from within
// a single run.
message ResourceHandleProto {
  // Unique name for the device containing the resource.
  string device = 1;

  // Container in which this resource is placed.
  string container = 2;

  // Unique name of this resource.
  string name = 3;

  // Hash code for the type of the resource. Is only valid in the same device
  // and in the same execution.
  uint64 hash_code = 4;

  // For debug-only, the name of the type pointed to by this handle, if
  // available.
  string maybe_type_name = 5;
};
@@ -0,0 +1,88 @@ tensor.proto (new file)
syntax = "proto3";

package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "TensorProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "resource_handle.proto";
import "tensor_shape.proto";
import "types.proto";

// Protocol buffer representing a tensor.
message TensorProto {
  DataType dtype = 1;

  // Shape of the tensor. TODO(touts): sort out the 0-rank issues.
  TensorShapeProto tensor_shape = 2;

  // Only one of the representations below is set, one of "tensor_contents" and
  // the "xxx_val" attributes. We are not using oneof because as oneofs cannot
  // contain repeated fields it would require another extra set of messages.

  // Version number.
  //
  // In version 0, if the "repeated xxx" representations contain only one
  // element, that element is repeated to fill the shape. This makes it easy
  // to represent a constant Tensor with a single value.
  int32 version_number = 3;

  // Serialized raw tensor content from either Tensor::AsProtoTensorContent or
  // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation
  // can be used for all tensor types. The purpose of this representation is to
  // reduce serialization overhead during RPC call by avoiding serialization of
  // many repeated small items.
  bytes tensor_content = 4;

  // Type specific representations that make it easy to create tensor protos in
  // all languages. Only the representation corresponding to "dtype" can
  // be set. The values hold the flattened representation of the tensor in
  // row major order.

  // DT_HALF. Note that since protobuf has no int16 type, we'll have some
  // pointless zero padding for each value here.
  repeated int32 half_val = 13 [packed = true];

  // DT_FLOAT.
  repeated float float_val = 5 [packed = true];

  // DT_DOUBLE.
  repeated double double_val = 6 [packed = true];

  // DT_INT32, DT_INT16, DT_INT8, DT_UINT8.
  repeated int32 int_val = 7 [packed = true];

  // DT_STRING
  repeated bytes string_val = 8;

  // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real
  // and imaginary parts of i-th single precision complex.
  repeated float scomplex_val = 9 [packed = true];

  // DT_INT64
  repeated int64 int64_val = 10 [packed = true];

  // DT_BOOL
  repeated bool bool_val = 11 [packed = true];

  // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real
  // and imaginary parts of i-th double precision complex.
  repeated double dcomplex_val = 12 [packed = true];

  // DT_RESOURCE
  repeated ResourceHandleProto resource_handle_val = 14;

  // DT_VARIANT
  repeated VariantTensorDataProto variant_val = 15;
};

// Protocol buffer representing the serialization format of DT_VARIANT tensors.
message VariantTensorDataProto {
  // Name of the type of objects being serialized.
  string type_name = 1;
  // Portions of the object that are not Tensors.
  bytes metadata = 2;
  // Tensors contained within objects being serialized.
  repeated TensorProto tensors = 3;
}
@@ -0,0 +1,45 @@ tensor_shape.proto (new file)
// Protocol buffer representing the shape of tensors.

syntax = "proto3";
option cc_enable_arenas = true;
option java_outer_classname = "TensorShapeProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

package tensorflow;

// Dimensions of a tensor.
message TensorShapeProto {
  // One dimension of the tensor.
  message Dim {
    // Size of the tensor in that dimension.
    // This value must be >= -1, but values of -1 are reserved for "unknown"
    // shapes (values of -1 mean "unknown" dimension). Certain wrappers
    // that work with TensorShapeProto may fail at runtime when deserializing
    // a TensorShapeProto containing a dim value of -1.
    int64 size = 1;

    // Optional name of the tensor dimension.
    string name = 2;
  };

  // Dimensions of the tensor, such as {"input", 30}, {"output", 40}
  // for a 30 x 40 2D tensor. If an entry has size -1, this
  // corresponds to a dimension of unknown size. The names are
  // optional.
  //
  // The order of entries in "dim" matters: It indicates the layout of the
  // values in the tensor in-memory representation.
  //
  // The first entry in "dim" is the outermost dimension used to layout the
  // values, the last entry is the innermost dimension. This matches the
  // in-memory layout of RowMajor Eigen tensors.
  //
  // If "dim.size()" > 0, "unknown_rank" must be false.
  repeated Dim dim = 2;

  // If true, the number of dimensions in the shape is unknown.
  //
  // If true, "dim.size()" must be 0.
  bool unknown_rank = 3;
};
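ConvertConstTensor below relies on exactly these two messages: the element count comes from TensorShapeProto, and the payload comes either from the repeated float_val field (broadcast when it holds a single element) or from the raw tensor_content bytes. A hedged, stand-alone sketch of that logic using only the generated protobuf accessors (function name and include path are assumptions for this example):

```cpp
#include <cstring>
#include <vector>
#include "proto/tensor.pb.h"  // generated from tensor.proto; path assumed to follow this commit's layout

// Illustrative: flatten a DT_FLOAT TensorProto into a std::vector<float>.
bool UnpackFloatTensor(const tensorflow::TensorProto &tensor, std::vector<float> *out) {
  int64_t element_count = 1;
  const tensorflow::TensorShapeProto &shape = tensor.tensor_shape();
  for (int i = 0; i < shape.dim_size(); ++i) {
    element_count *= shape.dim(i).size();
  }
  out->assign(element_count, 0.0f);
  if (tensor.float_val_size() == 1) {
    // A single value is repeated to fill the whole shape (see the version_number comment above).
    out->assign(element_count, tensor.float_val(0));
    return true;
  }
  if (tensor.tensor_content().size() == element_count * sizeof(float)) {
    std::memcpy(out->data(), tensor.tensor_content().data(), element_count * sizeof(float));
    return true;
  }
  return false;
}
```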
@@ -0,0 +1,34 @@ tf_add_parser.cc (new file)
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/converter/parser/tf/tf_add_parser.h"
#include <string>
#include <memory>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"

namespace mindspore {
namespace lite {
STATUS TFAddParser::Parse(const tensorflow::NodeDef *tf_op, const std::unique_ptr<tensorflow::GraphDef> &tf_model,
                          PrimitiveC *primitiveC, int *output_size) {
  auto attr = std::make_unique<schema::PrimitiveT>();
  attr->value.type = schema::PrimitiveType_Add;
  primitiveC = PrimitiveC::Create(attr.release());
  MS_LOG(INFO) << "primitive name: " << primitiveC->type_name();
  return RET_OK;
}
TFNodeRegistrar g_tfAddParser("Add", new TFAddParser());
}  // namespace lite
}  // namespace mindspore
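Other single-output element-wise ops would follow the same pattern as TFAddParser. The sketch below is an assumption, not part of this commit: TFMulParser is hypothetical, and PrimitiveType_Mul is assumed to exist alongside PrimitiveType_Add in the schema.

```cpp
// Hypothetical follow-up parser, mirroring tf_add_parser.cc; not part of this commit.
#include <memory>
#include "tools/converter/parser/tf/tf_node_parser.h"
#include "tools/converter/parser/tf/tf_node_parser_registry.h"

namespace mindspore {
namespace lite {
class TFMulParser : public TFNodeParser {
 public:
  STATUS Parse(const tensorflow::NodeDef *tf_op, const std::unique_ptr<tensorflow::GraphDef> &tf_model,
               PrimitiveC *primitiveC, int *output_size) override {
    auto attr = std::make_unique<schema::PrimitiveT>();
    attr->value.type = schema::PrimitiveType_Mul;  // assumed to exist alongside PrimitiveType_Add
    // Mirrors tf_add_parser.cc: the created primitive is assigned to the pointer
    // parameter exactly as the interface in this commit does.
    primitiveC = PrimitiveC::Create(attr.release());
    return RET_OK;
  }
};
// Static registrar: the constructor runs at load time and inserts the parser
// into TFNodeParserRegistry under the TensorFlow op name "Mul".
TFNodeRegistrar g_tfMulParser("Mul", new TFMulParser());
}  // namespace lite
}  // namespace mindspore
```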
@@ -0,0 +1,35 @@ tf_add_parser.h (new file)
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_ADD_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_ADD_PARSER_H

#include <memory>
#include "tools/converter/parser/tf/tf_node_parser.h"

namespace mindspore {
namespace lite {
class TFAddParser : public TFNodeParser {
 public:
  TFAddParser() = default;
  ~TFAddParser() override = default;

  STATUS Parse(const tensorflow::NodeDef *tf_op, const std::unique_ptr<tensorflow::GraphDef> &tf_model,
               PrimitiveC *primitiveC, int *output_size) override;
};
}  // namespace lite
}  // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_ADD_PARSER_H
@@ -0,0 +1,286 @@ tf_model_parser.cc (new file)
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/converter/parser/tf/tf_model_parser.h"
#include <map>
#include <algorithm>
#include "src/common/log_adapter.h"
#include "tools/converter/parser/tf/tf_util.h"
#include "tools/common/graph_util.h"
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "src/param_value_lite.h"

namespace mindspore {
namespace lite {
FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::string &weightFile,
                                  const QuantType &quantType) {
  auto status = ValidateFileStr(modelFile, ".prototxt");
  if (status != RET_OK) {
    MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.prototxt";
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
    return nullptr;
  }
  if (!TensorFlowUtils::TfReadProtoFromBinary(modelFile.c_str(), tf_graph_def.get())) {
    MS_LOG(ERROR) << "Open modelFile for TF converter failed!";
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
    return nullptr;
  }
  funcGraphPtr = std::make_shared<FuncGraph>();
  status = ConvertGraphInputs();
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Convert graph inputs failed.";
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
    return nullptr;
  }
  status = ConvertOps();
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Convert ops failed.";
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
    return nullptr;
  }

  status = ConvertGraphOutputs();
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Convert graph outputs failed.";
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
    return nullptr;
  }
  return funcGraphPtr;
}

STATUS TFModelParser::ConvertConstTensor(const tensorflow::NodeDef *node, ParameterPtr parameter) {
  tensorflow::AttrValue attr_value;
  if (TensorFlowUtils::FindAttrValue(node, "value", &attr_value)) {
    tensorflow::AttrValue data_type;
    tensorflow::DataType type = tensorflow::DT_FLOAT;
    // datatype
    if (TensorFlowUtils::FindAttrValue(node, "dtype", &data_type)) {
      type = data_type.type();
    }
    const tensorflow::TensorProto &tensorProto = attr_value.tensor();
    const tensorflow::TensorShapeProto &tensorShape = tensorProto.tensor_shape();
    parameter = funcGraphPtr->add_parameter();
    std::vector<int64_t> shape_vector;
    int shape_size = 1;
    shape_vector.resize(tensorShape.dim_size());
    for (int i = 0; i < tensorShape.dim_size(); i++) {
      shape_vector[i] = tensorShape.dim(i).size();
      shape_size *= shape_vector[i];
    }
    // convert const to parameter
    TypePtr ms_data_type;
    auto paramValue = std::make_shared<ParamValueLite>();
    if (type == tensorflow::DT_FLOAT) {
      ms_data_type = kFloat32;
      auto tensor_data = new (std::nothrow) float[shape_size];
      if (tensorProto.float_val_size() == 1) {
        float value = tensorProto.float_val(0);
        for (int i = 0; i < shape_size; i++) {
          tensor_data[i] = value;
        }
      }
      if (tensorProto.tensor_content().size() == shape_size * sizeof(float)) {
        const auto addr = reinterpret_cast<const float *>(tensorProto.tensor_content().data());
        auto ret = ::memcpy_s(tensor_data, shape_size * sizeof(float), addr, shape_size * sizeof(float));
        if (ret != EOK) {
          MS_LOG(ERROR) << "memcpy_s failed";
          return RET_ERROR;
        }
      }
      paramValue->set_tensor_addr(tensor_data);
      paramValue->set_tensor_size(shape_size * sizeof(float));
    } else if (type == tensorflow::DT_INT32) {
      ms_data_type = kInt32;
      auto tensor_data = new (std::nothrow) int[shape_size];
      if (tensorProto.int_val_size() == 1) {
        int value = tensorProto.int_val(0);
        for (int i = 0; i < shape_size; i++) {
          tensor_data[i] = value;
        }
      }
      if (tensorProto.tensor_content().size() == shape_size * sizeof(int32_t)) {
        const auto addr = reinterpret_cast<const int32_t *>(tensorProto.tensor_content().data());
        auto ret = ::memcpy_s(tensor_data, shape_size * sizeof(int32_t), addr, shape_size * sizeof(int32_t));
        if (ret != EOK) {
          MS_LOG(ERROR) << "memcpy_s failed";
          return RET_ERROR;
        }
      }
      paramValue->set_tensor_addr(tensor_data);
      paramValue->set_tensor_size(shape_size * sizeof(int));
    } else if (type == tensorflow::DT_BOOL) {
      ms_data_type = kFloat32;
      auto tensor_data = new (std::nothrow) int[shape_size];
      if (tensorProto.bool_val_size() == 1) {
        int value = tensorProto.bool_val(0);
        for (int i = 0; i < shape_size; i++) {
          tensor_data[i] = value;
        }
      }
      paramValue->set_tensor_addr(tensor_data);
      paramValue->set_tensor_size(shape_size * sizeof(int));
    } else {
      MS_LOG(ERROR) << "Unsupported dataType, " << node->name();
      return RET_ERROR;
    }
    auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(ms_data_type, shape_vector);
    parameter->set_abstract(abstract_tensor);
    parameter->set_name("const_" + std::to_string(anf_node_map.size()) + "_parameter");

    std::vector<int> param_shape;
    (void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(param_shape),
                         [](const int64_t &value) { return static_cast<int>(value); });

    MS_ASSERT(paramValue != nullptr);
    paramValue->set_tensor_shape(param_shape);
    paramValue->set_tensor_type(ms_data_type->type_id());
    paramValue->set_format(schema::Format::Format_NHWC);
    paramValue->set_tensor_size(shape_size * sizeof(int));
    parameter->set_default_param(paramValue);
  }
  return RET_OK;
}

STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef *op, const CNodePtr &anf_node, int output_size) {
  if (output_size == 1) {
    std::vector<int64_t> shape_vector;
    anf_node->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector));
    anf_node_map.insert(std::pair(op->name(), anf_node));
  } else {
    AbstractBasePtrList abstractList;
    for (int output_idx = 0; output_idx < output_size; output_idx++) {
      std::vector<int64_t> shape_vector;
      abstractList.emplace_back(std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector));
      auto tupleGetItemPrimPtr = GetTupleGetItemPrim();
      if (tupleGetItemPrimPtr == nullptr) {
        MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr";
        return RET_NULL_PTR;
      }
      auto tupleGetItemPrim = NewValueNode(tupleGetItemPrimPtr);
      auto getItemValue = NewValueNode(MakeValue<int>(output_idx));
      std::vector<AnfNodePtr> inputs{tupleGetItemPrim, anf_node, getItemValue};
      CNodePtr getItemCNode = funcGraphPtr->NewCNode(inputs);
      std::string output_item_name = anf_node->fullname_with_scope() + "_getitem_" + std::to_string(output_idx);
      getItemCNode->set_fullname_with_scope(output_item_name);
      anf_node_map.insert(std::pair(output_item_name, getItemCNode));
    }
    anf_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstractList));
  }
  return RET_OK;
}

STATUS TFModelParser::ConvertOps() {
  NoSupportOp::GetInstance()->SetFmkType("TENSORFLOW");
  STATUS status = RET_OK;

  // redirect identity nodes to their input 0
  ClipIdentityAndStopGradient();
  int op_idx = 0;
  for (int i = 0; i < tf_graph_def->node_size(); i++) {
    auto node_def = tf_graph_def->mutable_node(i);
    tf_node_map[node_def->name()] = node_def;
    auto tf_op_type = node_def->op();
    if (tf_op_type == "Placeholder" || tf_op_type == "Const") {
      continue;
    }
    auto node_parser = TFNodeParserRegistry::GetInstance()->GetNodeParser(tf_op_type);
    if (node_parser == nullptr) {
      NoSupportOp::GetInstance()->InsertOp(tf_op_type);
      status = (status == RET_OK ? RET_NOT_FIND_OP : status);
      MS_LOG(ERROR) << "cannot find node parser:" << tf_op_type;
      continue;
    }
    PrimitiveC *primitiveC = nullptr;
    if (status == RET_OK) {
      int output_size = 1;
      status = node_parser->Parse(node_def, tf_graph_def, primitiveC, &output_size);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "node " << tf_op_type.c_str() << " parser failed";
        continue;
      }
      std::vector<AnfNodePtr> opInputs = {NewValueNode(std::shared_ptr<PrimitiveC>(primitiveC))};
      // parse inputs
      for (int j = 0; j < node_def->input_size(); j++) {
        auto input_node = tf_node_map[node_def->input(j)];
        // last node output
        if (anf_node_map.find(input_node->name()) != anf_node_map.end()) {
          opInputs.emplace_back(anf_node_map[input_node->name()]);
          continue;
        }
        // const tensor
        if (input_node->op() == "Const") {
          ParameterPtr parameter;
          if (ConvertConstTensor(input_node, parameter) != RET_OK) {
            MS_LOG(ERROR) << "convert const tensor failed, " << input_node->name();
            return RET_ERROR;
          }
          opInputs.emplace_back(parameter);
          anf_node_map[parameter->fullname_with_scope()] = parameter;
          continue;
        }
        MS_LOG(ERROR) << "node " << node_def->name() << " has inputs that are neither a node output nor a weight tensor.";
        return RET_ERROR;
      }
      auto anf_node = funcGraphPtr->NewCNode(opInputs);
      anf_node->set_fullname_with_scope(tf_op_type + "-" + std::to_string(op_idx++));

      // parse outputs
      status = ConvertOutputTensor(node_def, anf_node, output_size);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed.";
        ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
        return status;
      }
    }
    // redirect identity nodes to their input 0
    ClipIdentityAndStopGradient();
  }
  return RET_OK;
}

STATUS TFModelParser::ConvertGraphInputs() {
  for (int i = 0; i < tf_graph_def->node_size(); i++) {
    auto node_def = tf_graph_def->mutable_node(i);
    tf_node_map[node_def->name()] = node_def;
    if (node_def->op() == "Placeholder") {
      auto parameter = funcGraphPtr->add_parameter();
      if (ConvertConstTensor(node_def, parameter) != RET_OK) {
        MS_LOG(ERROR) << "convert const tensor failed";
        return RET_ERROR;
      }
      anf_node_map[node_def->name()] = parameter;
      graph_input_names.emplace_back(node_def->name());
    }
  }
  return RET_OK;
}

STATUS TFModelParser::ConvertGraphOutputs() { return RET_OK; }

std::string TFModelParser::GetOriginInputName(const tensorflow::NodeDef &node) {
  if (node.op() != "Identity" && node.op() != "StopGradient") {
    return node.name();
  }
  auto tmpNode = node;
  while (tmpNode.op() == "Identity" || tmpNode.op() == "StopGradient") {
    tmpNode = *tf_node_map[tmpNode.input(0)];
  }
  return tmpNode.name();
}

void TFModelParser::ClipIdentityAndStopGradient() {
  for (auto &pair : tf_node_map) {
    pair.second = tf_node_map[GetOriginInputName(*pair.second)];
  }
}
}  // namespace lite
}  // namespace mindspore
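For orientation, the parser would be driven roughly as below. This is a sketch rather than the converter's actual call site: the driver function is invented, the file name is a placeholder, and the quant-type spelling (the schema's "no quantization" value) is an assumption.

```cpp
#include <string>
#include "tools/converter/parser/tf/tf_model_parser.h"

namespace mindspore {
namespace lite {
// Illustrative driver; the real converter reaches this parser through its own
// framework dispatch, and the quant-type spelling below is an assumption.
FuncGraphPtr ConvertTFModel(const std::string &model_file) {
  TFModelParser parser;
  // weightFile is unused for a frozen GraphDef; pass the schema's no-quantization value.
  return parser.Parse(model_file, "", schema::QuantType_QUANT_NONE);
}
}  // namespace lite
}  // namespace mindspore
```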
@@ -0,0 +1,60 @@ tf_model_parser.h (new file)
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_MODEL_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_MODEL_PARSER_H

#include <string>
#include <vector>
#include <memory>
#include <map>
#include <unordered_map>
#include "securec/include/securec.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/model_parser.h"
#include "schema/inner/model_generated.h"
#include "proto/node_def.pb.h"
#include "proto/graph.pb.h"

namespace mindspore {
namespace lite {
class TFModelParser {
 public:
  TFModelParser() = default;
  ~TFModelParser() = default;

  FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType);

 private:
  STATUS ConvertConstTensor(const tensorflow::NodeDef *op, ParameterPtr parameter);
  STATUS ConvertOutputTensor(const tensorflow::NodeDef *op, const CNodePtr &anf_node, int output_size);
  STATUS ConvertOps();
  STATUS ConvertGraphInputs();
  STATUS ConvertGraphOutputs();

  std::string GetOriginInputName(const tensorflow::NodeDef &node);

  void ClipIdentityAndStopGradient();

  FuncGraphPtr funcGraphPtr;
  std::unique_ptr<tensorflow::GraphDef> tf_graph_def;
  std::map<std::string, const tensorflow::NodeDef *> tf_node_map;
  std::unordered_map<std::string, AnfNodePtr> anf_node_map;
  std::vector<std::string> graph_input_names, graphOutputNames;
};
}  // namespace lite
}  // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_MODEL_PARSER_H
@@ -0,0 +1,43 @@ tf_node_parser.h (new file)
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_NODE_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_NODE_PARSER_H

#include <string>
#include <map>
#include <memory>
#include "tools/converter/parser/tf/tf_util.h"
#include "proto/graph.pb.h"
#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {
class TFNodeParser {
 public:
  TFNodeParser() = default;

  virtual ~TFNodeParser() = default;

  virtual STATUS Parse(const tensorflow::NodeDef *tf_op, const std::unique_ptr<tensorflow::GraphDef> &tf_model,
                       PrimitiveC *primitiveC, int *output_size) {
    return RET_OK;
  }
};
}  // namespace lite
}  // namespace mindspore

#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_NODE_PARSER_H
@@ -0,0 +1,44 @@ tf_node_parser_registry.cc (new file)
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include <map>
#include "src/common/log_adapter.h"

namespace mindspore {
namespace lite {
TFNodeParserRegistry::~TFNodeParserRegistry() {
  for (const auto &iter : parsers) {
    delete iter.second;
  }
  this->parsers.clear();
}

TFNodeParserRegistry *TFNodeParserRegistry::GetInstance() {
  static TFNodeParserRegistry instance;
  return &instance;
}

TFNodeParser *TFNodeParserRegistry::GetNodeParser(const std::string &name) {
  auto it = parsers.find(name);
  if (it != parsers.end()) {
    return it->second;
  }
  return nullptr;
}
}  // namespace lite
}  // namespace mindspore
@@ -0,0 +1,47 @@ tf_node_parser_registry.h (new file)
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_NODE_PARSER_REGISTRY_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_NODE_PARSER_REGISTRY_H

#include <string>
#include <unordered_map>
#include "tools/converter/parser/tf/tf_node_parser.h"

namespace mindspore {
namespace lite {
class TFNodeParserRegistry {
 public:
  TFNodeParserRegistry() = default;

  virtual ~TFNodeParserRegistry();

  static TFNodeParserRegistry *GetInstance();
  TFNodeParser *GetNodeParser(const std::string &name);

  std::unordered_map<std::string, TFNodeParser *> parsers;
};

class TFNodeRegistrar {
 public:
  TFNodeRegistrar(const std::string &name, TFNodeParser *parser) {
    TFNodeParserRegistry::GetInstance()->parsers[name] = parser;
  }
  ~TFNodeRegistrar() = default;
};
}  // namespace lite
}  // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_NODE_PARSER_REGISTRY_H
@@ -0,0 +1,55 @@ tf_util.cc (new file)
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/converter/parser/tf/tf_util.h"
#include <climits>
#include <cstdio>
#include <fstream>
#include <string>
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"

namespace mindspore {
namespace lite {
bool TensorFlowUtils::FindAttrValue(const tensorflow::NodeDef *nodeDef, std::string attr_name,
                                    tensorflow::AttrValue *attr_value) {
  const google::protobuf::Map<std::string, tensorflow::AttrValue> &attr = nodeDef->attr();
  const google::protobuf::Map<std::string, tensorflow::AttrValue>::const_iterator it = attr.find(attr_name);
  if (it != attr.end()) {
    *attr_value = it->second;
    return true;
  }
  return false;
}

bool TensorFlowUtils::TfReadProtoFromBinary(const char *filepath, google::protobuf::Message *message) {
  std::ifstream fs(filepath, std::ifstream::in | std::ifstream::binary);
  if (!fs.is_open()) {
    fprintf(stderr, "open failed %s\n", filepath);
    return false;
  }

  google::protobuf::io::IstreamInputStream input(&fs);
  google::protobuf::io::CodedInputStream codedstr(&input);

  codedstr.SetTotalBytesLimit(INT_MAX, INT_MAX / 2);

  bool success = message->ParseFromCodedStream(&codedstr);

  fs.close();

  return success;
}
}  // namespace lite
}  // namespace mindspore
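Loading a frozen GraphDef with this helper is then a two-liner. A small hedged example follows; the wrapper function is invented for illustration and the file path is a placeholder:

```cpp
#include <memory>
#include "proto/graph.pb.h"
#include "tools/converter/parser/tf/tf_util.h"

// Illustrative: read a frozen TensorFlow GraphDef from disk.
std::unique_ptr<tensorflow::GraphDef> LoadGraphDef(const char *path) {
  auto graph_def = std::make_unique<tensorflow::GraphDef>();
  if (!mindspore::lite::TensorFlowUtils::TfReadProtoFromBinary(path, graph_def.get())) {
    return nullptr;  // open or parse failure; the helper already printed a diagnostic
  }
  return graph_def;
}
```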
@@ -0,0 +1,36 @@ tf_util.h (new file)
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_UTIL_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_UTIL_H

#include <string>
#include "proto/node_def.pb.h"
#include "ir/dtype/type_id.h"
#include "include/errorcode.h"

namespace mindspore {
namespace lite {
class TensorFlowUtils {
 public:
  static bool FindAttrValue(const tensorflow::NodeDef *nodeDef, std::string attr_name,
                            tensorflow::AttrValue *attr_value);

  static bool TfReadProtoFromBinary(const char *filepath, google::protobuf::Message *message);
};
}  // namespace lite
}  // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_UTIL_H
@@ -0,0 +1,66 @@ types.proto (new file)
syntax = "proto3";

package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "TypesProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

// LINT.IfChange
enum DataType {
  // Not a legal value for DataType. Used to indicate a DataType field
  // has not been set.
  DT_INVALID = 0;

  // Data types that all computation devices are expected to be
  // capable to support.
  DT_FLOAT = 1;
  DT_DOUBLE = 2;
  DT_INT32 = 3;
  DT_UINT8 = 4;
  DT_INT16 = 5;
  DT_INT8 = 6;
  DT_STRING = 7;
  DT_COMPLEX64 = 8;  // Single-precision complex
  DT_INT64 = 9;
  DT_BOOL = 10;
  DT_QINT8 = 11;     // Quantized int8
  DT_QUINT8 = 12;    // Quantized uint8
  DT_QINT32 = 13;    // Quantized int32
  DT_BFLOAT16 = 14;  // Float32 truncated to 16 bits. Only for cast ops.
  DT_QINT16 = 15;    // Quantized int16
  DT_QUINT16 = 16;   // Quantized uint16
  DT_UINT16 = 17;
  DT_COMPLEX128 = 18;  // Double-precision complex
  DT_HALF = 19;
  DT_RESOURCE = 20;
  DT_VARIANT = 21;  // Arbitrary C++ data types

  // TODO(josh11b): DT_GENERIC_PROTO = ??;
  // TODO(jeff,josh11b): DT_UINT64? DT_UINT32?

  // Do not use! These are only for parameters. Every enum above
  // should have a corresponding value below (verified by types_test).
  DT_FLOAT_REF = 101;
  DT_DOUBLE_REF = 102;
  DT_INT32_REF = 103;
  DT_UINT8_REF = 104;
  DT_INT16_REF = 105;
  DT_INT8_REF = 106;
  DT_STRING_REF = 107;
  DT_COMPLEX64_REF = 108;
  DT_INT64_REF = 109;
  DT_BOOL_REF = 110;
  DT_QINT8_REF = 111;
  DT_QUINT8_REF = 112;
  DT_QINT32_REF = 113;
  DT_BFLOAT16_REF = 114;
  DT_QINT16_REF = 115;
  DT_QUINT16_REF = 116;
  DT_UINT16_REF = 117;
  DT_COMPLEX128_REF = 118;
  DT_HALF_REF = 119;
  DT_RESOURCE_REF = 120;
  DT_VARIANT_REF = 121;
}
// LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.h,https://www.tensorflow.org/code/tensorflow/go/tensor.go)
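ConvertConstTensor above maps only DT_FLOAT, DT_INT32 and DT_BOOL. A broader mapping to MindSpore's TypeId would look roughly like the sketch below; the TypeId enumerators come from ir/dtype/type_id.h (already included via tf_util.h), but treat the exact coverage and the function name as assumptions rather than this commit's behaviour.

```cpp
#include "ir/dtype/type_id.h"
#include "proto/types.pb.h"

// Illustrative mapping from tensorflow::DataType to mindspore::TypeId.
mindspore::TypeId TFTypeToMSType(tensorflow::DataType type) {
  switch (type) {
    case tensorflow::DT_FLOAT:
      return mindspore::kNumberTypeFloat32;
    case tensorflow::DT_INT32:
      return mindspore::kNumberTypeInt32;
    case tensorflow::DT_INT64:
      return mindspore::kNumberTypeInt64;
    case tensorflow::DT_BOOL:
      return mindspore::kNumberTypeBool;
    case tensorflow::DT_UINT8:
      return mindspore::kNumberTypeUInt8;
    default:
      return mindspore::kTypeUnknown;
  }
}
```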
@@ -0,0 +1,31 @@ versions.proto (new file)
syntax = "proto3";

package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "VersionsProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

// Version information for a piece of serialized data
//
// There are different types of versions for each type of data
// (GraphDef, etc.), but they all have the same common shape
// described here.
//
// Each consumer has "consumer" and "min_producer" versions (specified
// elsewhere). A consumer is allowed to consume this data if
//
//   producer >= min_producer
//   consumer >= min_consumer
//   consumer not in bad_consumers
//
message VersionDef {
  // The version of the code that produced this data.
  int32 producer = 1;

  // Any consumer below this version is not allowed to consume this data.
  int32 min_consumer = 2;

  // Specific consumer versions which are disallowed (e.g. due to bugs).
  repeated int32 bad_consumers = 3;
};