forked from OSchip/llvm-project
[mlgo][nfc] Decouple TensorSpec from tensorflow.
The motivation is twofold: 1) Allow plugging in a different training-time evaluator, e.g. TFLite-based, etc. 2) Allow using TensorSpec for AOT, too, to support evolution: we start by extracting a superset of the features currently supported by a model. For the tensors the model does not support, we just return a valid, but useless, buffer. This makes using a 'smaller' model (less supported tensors) transparent to the compiler. The key is to dimension the buffer appropriately, and we already have TensorSpec modeling that info. The only coupling was due to the reliance of a TF internal API for getting the element size, but for the types we are interested in, `sizeof` is sufficient. A subsequent change will yank out TensorSpec in its own module. Differential Revision: https://reviews.llvm.org/D124045
This commit is contained in:
parent
0c090dcc8a
commit
e4794ff5c6
|
@ -46,23 +46,48 @@ class EvaluationResultImpl;
|
|||
///
|
||||
/// TensorSpec is used to set up a TFModelEvaluator by describing the expected
|
||||
/// inputs and outputs.
|
||||
|
||||
/// Known tensor types. The left part is the C type, the right is a name we
|
||||
/// can use to identify the type (to implement TensorSpec equality checks), and
|
||||
/// to use, if needed, when mapping to an underlying evaluator's type system.
|
||||
/// The main requirement is that the C type we use has the same size and
|
||||
/// encoding (e.g. endian-ness) as the one used by the evaluator.
|
||||
#define SUPPORTED_TENSOR_TYPES(M) \
|
||||
M(float, Float) \
|
||||
M(double, Double) \
|
||||
M(int8_t, Int8) \
|
||||
M(uint8_t, UInt8) \
|
||||
M(int16_t, Int16) \
|
||||
M(uint16_t, UInt16) \
|
||||
M(int32_t, Int32) \
|
||||
M(uint32_t, UInt32) \
|
||||
M(int64_t, Int64) \
|
||||
M(uint64_t, UInt64)
|
||||
|
||||
enum class TensorType {
|
||||
Invalid,
|
||||
#define _TENSOR_TYPE_ENUM_MEMBERS(_, Name) Name,
|
||||
SUPPORTED_TENSOR_TYPES(_TENSOR_TYPE_ENUM_MEMBERS)
|
||||
#undef _TENSOR_TYPE_ENUM_MEMBERS
|
||||
};
|
||||
|
||||
class TensorSpec final {
|
||||
public:
|
||||
template <typename T>
|
||||
static TensorSpec createSpec(const std::string &Name,
|
||||
const std::vector<int64_t> &Shape,
|
||||
int Port = 0) {
|
||||
return TensorSpec(Name, Port, getDataType<T>(), Shape);
|
||||
return TensorSpec(Name, Port, getDataType<T>(), sizeof(T), Shape);
|
||||
}
|
||||
|
||||
const std::string &name() const { return Name; }
|
||||
int port() const { return Port; }
|
||||
int typeIndex() const { return TypeIndex; }
|
||||
TensorType type() const { return Type; }
|
||||
const std::vector<int64_t> &shape() const { return Shape; }
|
||||
|
||||
bool operator==(const TensorSpec &Other) const {
|
||||
return Name == Other.Name && Port == Other.Port &&
|
||||
TypeIndex == Other.TypeIndex && Shape == Other.Shape;
|
||||
return Name == Other.Name && Port == Other.Port && Type == Other.Type &&
|
||||
Shape == Other.Shape;
|
||||
}
|
||||
|
||||
bool operator!=(const TensorSpec &Other) const { return !(*this == Other); }
|
||||
|
@ -70,25 +95,24 @@ public:
|
|||
/// Get the number of elements in a tensor with this shape.
|
||||
size_t getElementCount() const { return ElementCount; }
|
||||
/// Get the size, in bytes, of one element.
|
||||
size_t getElementByteSize() const;
|
||||
size_t getElementByteSize() const { return ElementSize; }
|
||||
|
||||
template <typename T> bool isElementType() const {
|
||||
return getDataType<T>() == TypeIndex;
|
||||
return getDataType<T>() == Type;
|
||||
}
|
||||
|
||||
private:
|
||||
TensorSpec(const std::string &Name, int Port, int TypeIndex,
|
||||
const std::vector<int64_t> &Shape);
|
||||
TensorSpec(const std::string &Name, int Port, TensorType Type,
|
||||
size_t ElementSize, const std::vector<int64_t> &Shape);
|
||||
|
||||
template <typename T> static int getDataType() {
|
||||
llvm_unreachable("Undefined tensor type");
|
||||
}
|
||||
template <typename T> static TensorType getDataType();
|
||||
|
||||
std::string Name;
|
||||
int Port = 0;
|
||||
int TypeIndex = 0;
|
||||
TensorType Type = TensorType::Invalid;
|
||||
std::vector<int64_t> Shape;
|
||||
size_t ElementCount = 0;
|
||||
size_t ElementSize = 0;
|
||||
};
|
||||
|
||||
/// Construct a TensorSpec from a JSON dictionary of the form:
|
||||
|
@ -262,25 +286,9 @@ private:
|
|||
std::unique_ptr<TFModelEvaluatorImpl> Impl;
|
||||
};
|
||||
|
||||
/// List of supported types, as a pair:
|
||||
/// - C++ type
|
||||
/// - enum name (implementation-specific)
|
||||
#define TFUTILS_SUPPORTED_TYPES(M) \
|
||||
M(float, TF_FLOAT) \
|
||||
M(double, TF_DOUBLE) \
|
||||
M(int8_t, TF_INT8) \
|
||||
M(uint8_t, TF_UINT8) \
|
||||
M(int16_t, TF_INT16) \
|
||||
M(uint16_t, TF_UINT16) \
|
||||
M(int32_t, TF_INT32) \
|
||||
M(uint32_t, TF_UINT32) \
|
||||
M(int64_t, TF_INT64) \
|
||||
M(uint64_t, TF_UINT64)
|
||||
|
||||
#define TFUTILS_GETDATATYPE_DEF(T, E) \
|
||||
template <> int TensorSpec::getDataType<T>();
|
||||
|
||||
TFUTILS_SUPPORTED_TYPES(TFUTILS_GETDATATYPE_DEF)
|
||||
#define TFUTILS_GETDATATYPE_DEF(T, Name) \
|
||||
template <> TensorType TensorSpec::getDataType<T>();
|
||||
SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_DEF)
|
||||
|
||||
#undef TFUTILS_GETDATATYPE_DEF
|
||||
} // namespace llvm
|
||||
|
|
|
@ -82,6 +82,33 @@ void serialize(const Message &SE, std::string *OutStr) {
|
|||
*OutStr = SE.SerializeAsString();
|
||||
}
|
||||
}
|
||||
|
||||
int getTFTypeIndex(TensorType TType) {
|
||||
switch (TType) {
|
||||
case TensorType::Double:
|
||||
return TF_DOUBLE;
|
||||
case TensorType::Float:
|
||||
return TF_FLOAT;
|
||||
case TensorType::Int8:
|
||||
return TF_INT8;
|
||||
case TensorType::UInt8:
|
||||
return TF_UINT8;
|
||||
case TensorType::Int16:
|
||||
return TF_INT16;
|
||||
case TensorType::UInt16:
|
||||
return TF_UINT16;
|
||||
case TensorType::Int32:
|
||||
return TF_INT32;
|
||||
case TensorType::UInt32:
|
||||
return TF_UINT32;
|
||||
case TensorType::Int64:
|
||||
return TF_INT64;
|
||||
case TensorType::UInt64:
|
||||
return TF_UINT64;
|
||||
case TensorType::Invalid:
|
||||
llvm_unreachable("Unknown tensor type");
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace llvm {
|
||||
|
@ -105,15 +132,12 @@ private:
|
|||
std::vector<TF_Tensor *> Output;
|
||||
};
|
||||
|
||||
size_t TensorSpec::getElementByteSize() const {
|
||||
return TF_DataTypeSize(static_cast<TF_DataType>(TypeIndex));
|
||||
}
|
||||
|
||||
TensorSpec::TensorSpec(const std::string &Name, int Port, int TypeIndex,
|
||||
const std::vector<int64_t> &Shape)
|
||||
: Name(Name), Port(Port), TypeIndex(TypeIndex), Shape(Shape),
|
||||
TensorSpec::TensorSpec(const std::string &Name, int Port, TensorType Type,
|
||||
size_t ElementSize, const std::vector<int64_t> &Shape)
|
||||
: Name(Name), Port(Port), Type(Type), Shape(Shape),
|
||||
ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1,
|
||||
std::multiplies<int64_t>())) {}
|
||||
std::multiplies<int64_t>())),
|
||||
ElementSize(ElementSize) {}
|
||||
|
||||
Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
|
||||
const json::Value &Value) {
|
||||
|
@ -147,7 +171,7 @@ Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
|
|||
#define PARSE_TYPE(T, E) \
|
||||
if (TensorType == #T) \
|
||||
return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
|
||||
TFUTILS_SUPPORTED_TYPES(PARSE_TYPE)
|
||||
SUPPORTED_TENSOR_TYPES(PARSE_TYPE)
|
||||
#undef PARSE_TYPE
|
||||
return None;
|
||||
}
|
||||
|
@ -390,7 +414,7 @@ TFModelEvaluatorImpl::TFModelEvaluatorImpl(
|
|||
InputSpec.port()};
|
||||
if (!checkReportAndInvalidate(InputFeed[I], InputSpec))
|
||||
return;
|
||||
initInput(I, static_cast<TF_DataType>(InputSpec.typeIndex()),
|
||||
initInput(I, static_cast<TF_DataType>(getTFTypeIndex(InputSpec.type())),
|
||||
InputSpec.shape());
|
||||
}
|
||||
for (size_t I = 0; I < OutputSpecsSize; ++I) {
|
||||
|
@ -496,9 +520,9 @@ TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) const {
|
|||
}
|
||||
|
||||
#define TFUTILS_GETDATATYPE_IMPL(T, E) \
|
||||
template <> int TensorSpec::getDataType<T>() { return E; }
|
||||
template <> TensorType TensorSpec::getDataType<T>() { return TensorType::E; }
|
||||
|
||||
TFUTILS_SUPPORTED_TYPES(TFUTILS_GETDATATYPE_IMPL)
|
||||
SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_IMPL)
|
||||
|
||||
#undef TFUTILS_GETDATATYPE_IMPL
|
||||
|
||||
|
|
Loading…
Reference in New Issue