[llvm] Add a parser from JSON to TensorSpec

A JSON->TensorSpec utility we will use subsequently to specify
additional outputs needed for certain training scenarios.

Differential Revision: https://reviews.llvm.org/D84976
Author: Mircea Trofin
Date: 2020-07-30 12:44:07 -07:00
Parent: caf002c7be
Commit: 4b1b109c51
3 changed files with 102 additions and 28 deletions
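As a quick orientation before the diff: the new llvm::getTensorSpecFromJSON (declared and implemented below) turns a JSON dict with "name", "port", "type", and "shape" keys into a TensorSpec. A minimal usage sketch follows, guarded by LLVM_HAVE_TF_API like the rest of TFUtils; the "reward" spec and the parseOneSpec helper are purely illustrative, not part of this commit:

#include "llvm/ADT/Optional.h"
#include "llvm/Analysis/Utils/TFUtils.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/JSON.h"

// Parse a single spec, e.g. one extra output a training run wants recorded.
bool parseOneSpec(llvm::LLVMContext &Ctx) {
  auto V = llvm::json::parse(
      R"({"name": "reward", "port": 0, "type": "float", "shape": [1]})");
  if (!V) {
    llvm::consumeError(V.takeError()); // malformed JSON text
    return false;
  }
  // None is returned if the value does not describe a valid spec.
  llvm::Optional<llvm::TensorSpec> Spec = llvm::getTensorSpecFromJSON(Ctx, *V);
  return Spec.hasValue();
}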


@@ -13,6 +13,7 @@
#ifdef LLVM_HAVE_TF_API
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/JSON.h"
#include <memory>
#include <vector>
@@ -58,6 +59,13 @@ public:
int typeIndex() const { return TypeIndex; }
const std::vector<int64_t> &shape() const { return Shape; }
bool operator==(const TensorSpec &Other) const {
return Name == Other.Name && Port == Other.Port &&
TypeIndex == Other.TypeIndex && Shape == Other.Shape;
}
bool operator!=(const TensorSpec &Other) const { return !(*this == Other); }
private:
TensorSpec(const std::string &Name, int Port, int TypeIndex,
const std::vector<int64_t> &Shape)
@@ -73,6 +81,9 @@ private:
std::vector<int64_t> Shape;
};
Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
const json::Value &Value);
class TFModelEvaluator final {
public:
/// The result of a model evaluation. Handles the lifetime of the output
@@ -124,17 +135,28 @@ private:
std::unique_ptr<TFModelEvaluatorImpl> Impl;
};
template <> int TensorSpec::getDataType<float>();
template <> int TensorSpec::getDataType<double>();
template <> int TensorSpec::getDataType<int8_t>();
template <> int TensorSpec::getDataType<uint8_t>();
template <> int TensorSpec::getDataType<int16_t>();
template <> int TensorSpec::getDataType<uint16_t>();
template <> int TensorSpec::getDataType<int32_t>();
template <> int TensorSpec::getDataType<uint32_t>();
template <> int TensorSpec::getDataType<int64_t>();
template <> int TensorSpec::getDataType<uint64_t>();
/// List of supported types, as a triple:
/// C++ type
/// short name (for strings, for instance)
/// capitalized short name (for enums, for instance)
#define TFUTILS_SUPPORTED_TYPES(M) \
M(float, float, FLOAT) \
M(double, double, DOUBLE) \
M(int8_t, int8, INT8) \
M(uint8_t, uint8, UINT8) \
M(int16_t, int16, INT16) \
M(uint16_t, uint16, UINT16) \
M(int32_t, int32, INT32) \
M(uint32_t, uint32, UINT32) \
M(int64_t, int64, INT64) \
M(uint64_t, uint64, UINT64)
#define TFUTILS_GETDATATYPE_DEF(T, S, C) \
template <> int TensorSpec::getDataType<T>();
TFUTILS_SUPPORTED_TYPES(TFUTILS_GETDATATYPE_DEF)
#undef TFUTILS_GETDATATYPE_DEF
} // namespace llvm
#endif // LLVM_HAVE_TF_API
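The X-macro above replaces the ten hand-written getDataType declarations it sits next to. As a quick illustration, applying TFUTILS_GETDATATYPE_DEF to the first two rows of TFUTILS_SUPPORTED_TYPES expands to the same declarations the old code spelled out by hand:

// Expansion of TFUTILS_SUPPORTED_TYPES(TFUTILS_GETDATATYPE_DEF) for the
// first two rows; the remaining eight rows follow the same pattern.
template <> int TensorSpec::getDataType<float>();
template <> int TensorSpec::getDataType<double>();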


@@ -13,9 +13,10 @@
#include "llvm/Config/config.h"
#if defined(LLVM_HAVE_TF_API)
#include "llvm/Analysis/Utils/TFUtils.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/Utils/TFUtils.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/raw_ostream.h"
@@ -83,6 +84,41 @@ private:
std::vector<TF_Tensor *> Output;
};
Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
const json::Value &Value) {
auto EmitError = [&](const llvm::Twine &Message) -> Optional<TensorSpec> {
std::string S;
llvm::raw_string_ostream OS(S);
OS << Value;
Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S);
return None;
};
json::ObjectMapper Mapper(Value);
if (!Mapper)
return EmitError("Value is not a dict");
std::string TensorName;
int TensorPort = -1;
std::string TensorType;
std::vector<int64_t> TensorShape;
if (!Mapper.map<std::string>("name", TensorName))
return EmitError("'name' property not present or not a string");
if (!Mapper.map<std::string>("type", TensorType))
return EmitError("'type' property not present or not a string");
if (!Mapper.map<int>("port", TensorPort))
return EmitError("'port' property not present or not an int");
if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape))
return EmitError("'shape' property not present or not an int array");
#define PARSE_TYPE(T, S, E) \
if (TensorType == #S) \
return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
TFUTILS_SUPPORTED_TYPES(PARSE_TYPE)
#undef PARSE_TYPE
return None;
}
class TFModelEvaluatorImpl {
public:
TFModelEvaluatorImpl(StringRef SavedModelPath,
@@ -249,25 +285,12 @@ void *TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) {
return TF_TensorData(Impl->getOutput()[Index]);
}
template <> int TensorSpec::getDataType<float>() { return TF_FLOAT; }
template <> int TensorSpec::getDataType<double>() { return TF_DOUBLE; }
template <> int TensorSpec::getDataType<int8_t>() { return TF_INT8; }
template <> int TensorSpec::getDataType<uint8_t>() { return TF_UINT8; }
template <> int TensorSpec::getDataType<int16_t>() { return TF_INT16; }
template <> int TensorSpec::getDataType<uint16_t>() { return TF_UINT16; }
template <> int TensorSpec::getDataType<int32_t>() { return TF_INT32; }
template <> int TensorSpec::getDataType<uint32_t>() { return TF_UINT32; }
template <> int TensorSpec::getDataType<int64_t>() { return TF_INT64; }
template <> int TensorSpec::getDataType<uint64_t>() { return TF_UINT64; }
#define TFUTILS_GETDATATYPE_IMPL(T, S, E) \
template <> int TensorSpec::getDataType<T>() { return TF_##E; }
TFUTILS_SUPPORTED_TYPES(TFUTILS_GETDATATYPE_IMPL)
#undef TFUTILS_GETDATATYPE_IMPL
TFModelEvaluator::EvaluationResult::~EvaluationResult() {}
TFModelEvaluator::~TFModelEvaluator() {}
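Since the commit message mentions specifying several additional outputs, a caller will typically hold a list of specs. A hypothetical helper (not part of this commit) that collects every parseable entry from a JSON array, built only on the getTensorSpecFromJSON call added above:

#include "llvm/ADT/Optional.h"
#include "llvm/Analysis/Utils/TFUtils.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/JSON.h"
#include <vector>

// Hypothetical helper: walk a JSON array and keep the specs that parse.
// Entries that do not describe a valid spec are skipped; malformed dicts
// additionally report an error through Ctx (see getTensorSpecFromJSON).
std::vector<llvm::TensorSpec> collectSpecs(llvm::LLVMContext &Ctx,
                                           const llvm::json::Value &V) {
  std::vector<llvm::TensorSpec> Specs;
  if (const llvm::json::Array *A = V.getAsArray())
    for (const llvm::json::Value &E : *A)
      if (llvm::Optional<llvm::TensorSpec> Spec =
              llvm::getTensorSpecFromJSON(Ctx, E))
        Specs.push_back(*Spec);
  return Specs;
}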


@@ -94,3 +94,32 @@ TEST(TFUtilsTest, EvalError) {
EXPECT_FALSE(ER.hasValue());
EXPECT_FALSE(Evaluator.isValid());
}
TEST(TFUtilsTest, JSONParsing) {
auto Value = json::parse(
R"({"name": "tensor_name",
"port": 2,
"type": "int32",
"shape":[1,4]
})");
EXPECT_TRUE(!!Value);
LLVMContext Ctx;
Optional<TensorSpec> Spec = getTensorSpecFromJSON(Ctx, *Value);
EXPECT_TRUE(Spec.hasValue());
EXPECT_EQ(*Spec, TensorSpec::createSpec<int32_t>("tensor_name", {1, 4}, 2));
}
TEST(TFUtilsTest, JSONParsingInvalidTensorType) {
auto Value = json::parse(
R"(
{"name": "tensor_name",
"port": 2,
"type": "no such type",
"shape":[1,4]
}
)");
EXPECT_TRUE(!!Value);
LLVMContext Ctx;
auto Spec = getTensorSpecFromJSON(Ctx, *Value);
EXPECT_FALSE(Spec.hasValue());
}