forked from OSchip/llvm-project
[NFC][MLGO]Add RTTI support for MLModelRunner and simplify runner setup
This commit is contained in:
parent
e627f4ce0d
commit
a120fdd337
|
@ -41,8 +41,13 @@ public:
|
|||
getTensorUntyped(static_cast<size_t>(FeatureID)));
|
||||
}
|
||||
|
||||
enum class Kind : int { Unknown, Release, Development, NoOp };
|
||||
Kind getKind() const { return Type; }
|
||||
|
||||
protected:
|
||||
MLModelRunner(LLVMContext &Ctx) : Ctx(Ctx) {}
|
||||
MLModelRunner(LLVMContext &Ctx, Kind Type) : Ctx(Ctx), Type(Type) {
|
||||
assert(Type != Kind::Unknown);
|
||||
}
|
||||
virtual void *evaluateUntyped() = 0;
|
||||
virtual void *getTensorUntyped(size_t Index) = 0;
|
||||
const void *getTensorUntyped(size_t Index) const {
|
||||
|
@ -50,6 +55,7 @@ protected:
|
|||
}
|
||||
|
||||
LLVMContext &Ctx;
|
||||
const Kind Type;
|
||||
};
|
||||
} // namespace llvm
|
||||
|
||||
|
|
|
@ -26,17 +26,11 @@ namespace llvm {
|
|||
/// sacrificed for ease of use while training.
|
||||
class ModelUnderTrainingRunner final : public MLModelRunner {
|
||||
public:
|
||||
ModelUnderTrainingRunner(LLVMContext &Ctx, const std::string &ModelPath,
|
||||
const std::vector<TensorSpec> &InputSpecs,
|
||||
const std::vector<LoggedFeatureSpec> &OutputSpecs);
|
||||
|
||||
// Disallows copy and assign.
|
||||
ModelUnderTrainingRunner(const ModelUnderTrainingRunner &) = delete;
|
||||
ModelUnderTrainingRunner &
|
||||
operator=(const ModelUnderTrainingRunner &) = delete;
|
||||
|
||||
bool isValid() const { return !!Evaluator; }
|
||||
|
||||
const std::vector<LoggedFeatureSpec> &outputLoggedFeatureSpecs() const {
|
||||
return OutputSpecs;
|
||||
}
|
||||
|
@ -45,13 +39,27 @@ public:
|
|||
lastEvaluationResult() const {
|
||||
return LastEvaluationResult;
|
||||
}
|
||||
static bool classof(const MLModelRunner *R) {
|
||||
return R->getKind() == MLModelRunner::Kind::Development;
|
||||
}
|
||||
|
||||
static std::unique_ptr<ModelUnderTrainingRunner>
|
||||
createAndEnsureValid(LLVMContext &Ctx, const std::string &ModelPath,
|
||||
StringRef DecisionName,
|
||||
const std::vector<TensorSpec> &InputSpecs,
|
||||
StringRef OutputSpecsPathOverride = "");
|
||||
|
||||
private:
|
||||
ModelUnderTrainingRunner(LLVMContext &Ctx, const std::string &ModelPath,
|
||||
const std::vector<TensorSpec> &InputSpecs,
|
||||
const std::vector<LoggedFeatureSpec> &OutputSpecs);
|
||||
|
||||
std::unique_ptr<TFModelEvaluator> Evaluator;
|
||||
const std::vector<LoggedFeatureSpec> OutputSpecs;
|
||||
Optional<TFModelEvaluator::EvaluationResult> LastEvaluationResult;
|
||||
void *evaluateUntyped() override;
|
||||
void *getTensorUntyped(size_t Index) override;
|
||||
bool isValid() const { return !!Evaluator; }
|
||||
};
|
||||
|
||||
} // namespace llvm
|
||||
|
|
|
@ -26,6 +26,10 @@ public:
|
|||
NoInferenceModelRunner(LLVMContext &Ctx,
|
||||
const std::vector<TensorSpec> &Inputs);
|
||||
|
||||
static bool classof(const MLModelRunner *R) {
|
||||
return R->getKind() == MLModelRunner::Kind::NoOp;
|
||||
}
|
||||
|
||||
private:
|
||||
void *evaluateUntyped() override {
|
||||
llvm_unreachable("We shouldn't call run on this model runner.");
|
||||
|
|
|
@ -29,7 +29,8 @@ public:
|
|||
ReleaseModeModelRunner(LLVMContext &Ctx, const FType &FeatureNames,
|
||||
StringRef DecisionName, StringRef FeedPrefix = "feed_",
|
||||
StringRef FetchPrefix = "fetch_")
|
||||
: MLModelRunner(Ctx), CompiledModel(std::make_unique<TGen>()) {
|
||||
: MLModelRunner(Ctx, MLModelRunner::Kind::Release),
|
||||
CompiledModel(std::make_unique<TGen>()) {
|
||||
assert(CompiledModel && "The CompiledModel should be valid");
|
||||
|
||||
const size_t FeatureCount = FeatureNames.size();
|
||||
|
@ -49,6 +50,10 @@ public:
|
|||
|
||||
virtual ~ReleaseModeModelRunner() = default;
|
||||
|
||||
static bool classof(const MLModelRunner *R) {
|
||||
return R->getKind() == MLModelRunner::Kind::Release;
|
||||
}
|
||||
|
||||
private:
|
||||
void *evaluateUntyped() override {
|
||||
CompiledModel->Run();
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
#include "llvm/Config/config.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#if defined(LLVM_HAVE_TF_API)
|
||||
|
||||
#include "llvm/Analysis/CallGraph.h"
|
||||
|
@ -150,7 +151,7 @@ public:
|
|||
DevelopmentModeMLInlineAdvisor(
|
||||
Module &M, ModuleAnalysisManager &MAM,
|
||||
std::unique_ptr<MLModelRunner> ModelRunner,
|
||||
std::function<bool(CallBase &)> GetDefaultAdvice, bool IsDoingInference,
|
||||
std::function<bool(CallBase &)> GetDefaultAdvice,
|
||||
std::unique_ptr<TrainingLogger> Logger);
|
||||
|
||||
size_t getTotalSizeEstimate();
|
||||
|
@ -341,10 +342,11 @@ void TrainingLogger::print() {
|
|||
DevelopmentModeMLInlineAdvisor::DevelopmentModeMLInlineAdvisor(
|
||||
Module &M, ModuleAnalysisManager &MAM,
|
||||
std::unique_ptr<MLModelRunner> ModelRunner,
|
||||
std::function<bool(CallBase &)> GetDefaultAdvice, bool IsDoingInference,
|
||||
std::function<bool(CallBase &)> GetDefaultAdvice,
|
||||
std::unique_ptr<TrainingLogger> Logger)
|
||||
: MLInlineAdvisor(M, MAM, std::move(ModelRunner)),
|
||||
GetDefaultAdvice(GetDefaultAdvice), IsDoingInference(IsDoingInference),
|
||||
GetDefaultAdvice(GetDefaultAdvice),
|
||||
IsDoingInference(isa<ModelUnderTrainingRunner>(getModelRunner())),
|
||||
Logger(std::move(Logger)),
|
||||
InitialNativeSize(isLogging() ? getTotalSizeEstimate() : 0),
|
||||
CurrentNativeSize(InitialNativeSize) {
|
||||
|
@ -422,30 +424,20 @@ std::unique_ptr<InlineAdvisor> llvm::getDevelopmentModeAdvisor(
|
|||
std::function<bool(CallBase &)> GetDefaultAdvice) {
|
||||
auto &Ctx = M.getContext();
|
||||
std::unique_ptr<MLModelRunner> Runner;
|
||||
ModelUnderTrainingRunner *MUTRPtr = nullptr;
|
||||
bool IsDoingInference = false;
|
||||
if (TFModelUnderTrainingPath.empty())
|
||||
Runner.reset(new NoInferenceModelRunner(Ctx, getInputFeatures()));
|
||||
else {
|
||||
std::unique_ptr<ModelUnderTrainingRunner> MUTR;
|
||||
if (auto MaybeOutputSpecs = loadOutputSpecs(
|
||||
Ctx, DecisionName, TFModelUnderTrainingPath, TFOutputSpecOverride))
|
||||
MUTR = std::make_unique<ModelUnderTrainingRunner>(
|
||||
Ctx, TFModelUnderTrainingPath, getInputFeatures(), *MaybeOutputSpecs);
|
||||
if (!MUTR || !MUTR->isValid()) {
|
||||
Ctx.emitError("Could not load the policy model from the provided path");
|
||||
return nullptr;
|
||||
}
|
||||
IsDoingInference = true;
|
||||
MUTRPtr = MUTR.get();
|
||||
Runner = std::move(MUTR);
|
||||
}
|
||||
else
|
||||
Runner = ModelUnderTrainingRunner::createAndEnsureValid(
|
||||
Ctx, TFModelUnderTrainingPath, DecisionName, getInputFeatures(),
|
||||
TFOutputSpecOverride);
|
||||
if (!Runner)
|
||||
return nullptr;
|
||||
std::unique_ptr<TrainingLogger> Logger;
|
||||
if (!TrainingLog.empty())
|
||||
Logger = std::make_unique<TrainingLogger>(TrainingLog, MUTRPtr);
|
||||
Logger = std::make_unique<TrainingLogger>(
|
||||
TrainingLog, dyn_cast<ModelUnderTrainingRunner>(Runner.get()));
|
||||
|
||||
return std::make_unique<DevelopmentModeMLInlineAdvisor>(
|
||||
M, MAM, std::move(Runner), GetDefaultAdvice, IsDoingInference,
|
||||
std::move(Logger));
|
||||
M, MAM, std::move(Runner), GetDefaultAdvice, std::move(Logger));
|
||||
}
|
||||
#endif // defined(LLVM_HAVE_TF_API)
|
||||
|
|
|
@ -22,7 +22,8 @@ ModelUnderTrainingRunner::ModelUnderTrainingRunner(
|
|||
LLVMContext &Ctx, const std::string &ModelPath,
|
||||
const std::vector<TensorSpec> &InputSpecs,
|
||||
const std::vector<LoggedFeatureSpec> &OutputSpecs)
|
||||
: MLModelRunner(Ctx), OutputSpecs(OutputSpecs) {
|
||||
: MLModelRunner(Ctx, MLModelRunner::Kind::Development),
|
||||
OutputSpecs(OutputSpecs) {
|
||||
Evaluator = std::make_unique<TFModelEvaluator>(
|
||||
ModelPath, InputSpecs, [&](size_t I) { return OutputSpecs[I].Spec; },
|
||||
OutputSpecs.size());
|
||||
|
@ -46,4 +47,21 @@ void *ModelUnderTrainingRunner::getTensorUntyped(size_t Index) {
|
|||
return Evaluator->getUntypedInput(Index);
|
||||
}
|
||||
|
||||
std::unique_ptr<ModelUnderTrainingRunner>
|
||||
ModelUnderTrainingRunner::createAndEnsureValid(
|
||||
LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName,
|
||||
const std::vector<TensorSpec> &InputSpecs,
|
||||
StringRef OutputSpecsPathOverride) {
|
||||
std::unique_ptr<ModelUnderTrainingRunner> MUTR;
|
||||
if (auto MaybeOutputSpecs = loadOutputSpecs(Ctx, DecisionName, ModelPath,
|
||||
OutputSpecsPathOverride))
|
||||
MUTR.reset(new ModelUnderTrainingRunner(Ctx, ModelPath, InputSpecs,
|
||||
*MaybeOutputSpecs));
|
||||
if (MUTR && MUTR->isValid())
|
||||
return MUTR;
|
||||
|
||||
Ctx.emitError("Could not load the policy model from the provided path");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
#endif // defined(LLVM_HAVE_TF_API)
|
||||
|
|
|
@ -20,7 +20,7 @@ using namespace llvm;
|
|||
|
||||
NoInferenceModelRunner::NoInferenceModelRunner(
|
||||
LLVMContext &Ctx, const std::vector<TensorSpec> &Inputs)
|
||||
: MLModelRunner(Ctx) {
|
||||
: MLModelRunner(Ctx, MLModelRunner::Kind::NoOp) {
|
||||
ValuesBuffer.reserve(Inputs.size());
|
||||
for (const auto &TS : Inputs)
|
||||
ValuesBuffer.push_back(std::make_unique<char[]>(TS.getElementCount() *
|
||||
|
|
Loading…
Reference in New Issue