forked from OSchip/llvm-project
[llvm] Add a parser from JSON to TensorSpec
A JSON->TensorSpec utility we will use subsequently to specify additional outputs needed for certain training scenarios. Differential Revision: https://reviews.llvm.org/D84976
This commit is contained in:
parent
caf002c7be
commit
4b1b109c51
|
@ -13,6 +13,7 @@
|
|||
|
||||
#ifdef LLVM_HAVE_TF_API
|
||||
#include "llvm/IR/LLVMContext.h"
|
||||
#include "llvm/Support/JSON.h"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
@ -58,6 +59,13 @@ public:
|
|||
int typeIndex() const { return TypeIndex; }
|
||||
const std::vector<int64_t> &shape() const { return Shape; }
|
||||
|
||||
bool operator==(const TensorSpec &Other) const {
|
||||
return Name == Other.Name && Port == Other.Port &&
|
||||
TypeIndex == Other.TypeIndex && Shape == Other.Shape;
|
||||
}
|
||||
|
||||
bool operator!=(const TensorSpec &Other) const { return !(*this == Other); }
|
||||
|
||||
private:
|
||||
TensorSpec(const std::string &Name, int Port, int TypeIndex,
|
||||
const std::vector<int64_t> &Shape)
|
||||
|
@ -73,6 +81,9 @@ private:
|
|||
std::vector<int64_t> Shape;
|
||||
};
|
||||
|
||||
Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
|
||||
const json::Value &Value);
|
||||
|
||||
class TFModelEvaluator final {
|
||||
public:
|
||||
/// The result of a model evaluation. Handles the lifetime of the output
|
||||
|
@ -124,17 +135,28 @@ private:
|
|||
std::unique_ptr<TFModelEvaluatorImpl> Impl;
|
||||
};
|
||||
|
||||
template <> int TensorSpec::getDataType<float>();
|
||||
template <> int TensorSpec::getDataType<double>();
|
||||
template <> int TensorSpec::getDataType<int8_t>();
|
||||
template <> int TensorSpec::getDataType<uint8_t>();
|
||||
template <> int TensorSpec::getDataType<int16_t>();
|
||||
template <> int TensorSpec::getDataType<uint16_t>();
|
||||
template <> int TensorSpec::getDataType<int32_t>();
|
||||
template <> int TensorSpec::getDataType<uint32_t>();
|
||||
template <> int TensorSpec::getDataType<int64_t>();
|
||||
template <> int TensorSpec::getDataType<uint64_t>();
|
||||
/// List of supported types, as a triple:
|
||||
/// C++ type
|
||||
/// short name (for strings, for instance)
|
||||
/// capitalized short name (for enums, for instance)
|
||||
#define TFUTILS_SUPPORTED_TYPES(M) \
|
||||
M(float, float, FLOAT) \
|
||||
M(double, double, DOUBLE) \
|
||||
M(int8_t, int8, INT8) \
|
||||
M(uint8_t, uint8, UINT8) \
|
||||
M(int16_t, int16, INT16) \
|
||||
M(uint16_t, uint16, UINT16) \
|
||||
M(int32_t, int32, INT32) \
|
||||
M(uint32_t, uint32, UINT32) \
|
||||
M(int64_t, int64, INT64) \
|
||||
M(uint64_t, uint64, UINT64)
|
||||
|
||||
#define TFUTILS_GETDATATYPE_DEF(T, S, C) \
|
||||
template <> int TensorSpec::getDataType<T>();
|
||||
|
||||
TFUTILS_SUPPORTED_TYPES(TFUTILS_GETDATATYPE_DEF)
|
||||
|
||||
#undef TFUTILS_GETDATATYPE_DEF
|
||||
} // namespace llvm
|
||||
|
||||
#endif // LLVM_HAVE_TF_API
|
||||
|
|
|
@ -13,9 +13,10 @@
|
|||
#include "llvm/Config/config.h"
|
||||
#if defined(LLVM_HAVE_TF_API)
|
||||
|
||||
#include "llvm/Analysis/Utils/TFUtils.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Analysis/Utils/TFUtils.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/JSON.h"
|
||||
#include "llvm/Support/ManagedStatic.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
|
@ -83,6 +84,41 @@ private:
|
|||
std::vector<TF_Tensor *> Output;
|
||||
};
|
||||
|
||||
Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
|
||||
const json::Value &Value) {
|
||||
auto EmitError = [&](const llvm::Twine &Message) -> Optional<TensorSpec> {
|
||||
std::string S;
|
||||
llvm::raw_string_ostream OS(S);
|
||||
OS << Value;
|
||||
Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S);
|
||||
return None;
|
||||
};
|
||||
json::ObjectMapper Mapper(Value);
|
||||
if (!Mapper)
|
||||
return EmitError("Value is not a dict");
|
||||
|
||||
std::string TensorName;
|
||||
int TensorPort = -1;
|
||||
std::string TensorType;
|
||||
std::vector<int64_t> TensorShape;
|
||||
|
||||
if (!Mapper.map<std::string>("name", TensorName))
|
||||
return EmitError("'name' property not present or not a string");
|
||||
if (!Mapper.map<std::string>("type", TensorType))
|
||||
return EmitError("'type' property not present or not a string");
|
||||
if (!Mapper.map<int>("port", TensorPort))
|
||||
return EmitError("'port' property not present or not an int");
|
||||
if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape))
|
||||
return EmitError("'shape' property not present or not an int array");
|
||||
|
||||
#define PARSE_TYPE(T, S, E) \
|
||||
if (TensorType == #S) \
|
||||
return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
|
||||
TFUTILS_SUPPORTED_TYPES(PARSE_TYPE)
|
||||
#undef PARSE_TYPE
|
||||
return None;
|
||||
}
|
||||
|
||||
class TFModelEvaluatorImpl {
|
||||
public:
|
||||
TFModelEvaluatorImpl(StringRef SavedModelPath,
|
||||
|
@ -249,25 +285,12 @@ void *TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) {
|
|||
return TF_TensorData(Impl->getOutput()[Index]);
|
||||
}
|
||||
|
||||
template <> int TensorSpec::getDataType<float>() { return TF_FLOAT; }
|
||||
#define TFUTILS_GETDATATYPE_IMPL(T, S, E) \
|
||||
template <> int TensorSpec::getDataType<T>() { return TF_##E; }
|
||||
|
||||
template <> int TensorSpec::getDataType<double>() { return TF_DOUBLE; }
|
||||
TFUTILS_SUPPORTED_TYPES(TFUTILS_GETDATATYPE_IMPL)
|
||||
|
||||
template <> int TensorSpec::getDataType<int8_t>() { return TF_INT8; }
|
||||
|
||||
template <> int TensorSpec::getDataType<uint8_t>() { return TF_UINT8; }
|
||||
|
||||
template <> int TensorSpec::getDataType<int16_t>() { return TF_INT16; }
|
||||
|
||||
template <> int TensorSpec::getDataType<uint16_t>() { return TF_UINT16; }
|
||||
|
||||
template <> int TensorSpec::getDataType<int32_t>() { return TF_INT32; }
|
||||
|
||||
template <> int TensorSpec::getDataType<uint32_t>() { return TF_UINT32; }
|
||||
|
||||
template <> int TensorSpec::getDataType<int64_t>() { return TF_INT64; }
|
||||
|
||||
template <> int TensorSpec::getDataType<uint64_t>() { return TF_UINT64; }
|
||||
#undef TFUTILS_GETDATATYPE_IMPL
|
||||
|
||||
TFModelEvaluator::EvaluationResult::~EvaluationResult() {}
|
||||
TFModelEvaluator::~TFModelEvaluator() {}
|
||||
|
|
|
@ -94,3 +94,32 @@ TEST(TFUtilsTest, EvalError) {
|
|||
EXPECT_FALSE(ER.hasValue());
|
||||
EXPECT_FALSE(Evaluator.isValid());
|
||||
}
|
||||
|
||||
TEST(TFUtilsTest, JSONParsing) {
|
||||
auto Value = json::parse(
|
||||
R"({"name": "tensor_name",
|
||||
"port": 2,
|
||||
"type": "int32",
|
||||
"shape":[1,4]
|
||||
})");
|
||||
EXPECT_TRUE(!!Value);
|
||||
LLVMContext Ctx;
|
||||
Optional<TensorSpec> Spec = getTensorSpecFromJSON(Ctx, *Value);
|
||||
EXPECT_TRUE(Spec.hasValue());
|
||||
EXPECT_EQ(*Spec, TensorSpec::createSpec<int32_t>("tensor_name", {1, 4}, 2));
|
||||
}
|
||||
|
||||
TEST(TFUtilsTest, JSONParsingInvalidTensorType) {
|
||||
auto Value = json::parse(
|
||||
R"(
|
||||
{"name": "tensor_name",
|
||||
"port": 2,
|
||||
"type": "no such type",
|
||||
"shape":[1,4]
|
||||
}
|
||||
)");
|
||||
EXPECT_TRUE(!!Value);
|
||||
LLVMContext Ctx;
|
||||
auto Spec = getTensorSpecFromJSON(Ctx, *Value);
|
||||
EXPECT_FALSE(Spec.hasValue());
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue