[NFC][MLInliner] Set up the logger outside the development mode advisor

This allows us to subsequently configure the logger for the case when we
use a model evaluator and want to log additional outputs.

Differential Revision: https://reviews.llvm.org/D85577
This commit is contained in:
Mircea Trofin 2020-08-10 09:22:17 -07:00
parent 3b21a07fd7
commit d5c81be3ca
1 changed files with 25 additions and 16 deletions

View File

@ -71,14 +71,14 @@ struct InlineEvent {
/// lines up with how TF SequenceExample represents it.
class TrainingLogger final {
public:
TrainingLogger();
TrainingLogger(StringRef LogFileName);
/// Log one inlining event.
void logInlineEvent(const InlineEvent &Event,
const MLModelRunner &ModelRunner);
/// Print the stored tensors.
void print(raw_fd_ostream &OutFile);
void print();
private:
/// Write the values of one tensor as a list.
@ -156,6 +156,7 @@ private:
OutFile << " }\n";
}
StringRef LogFileName;
std::vector<InlineFeatures> Features;
std::vector<int64_t> DefaultDecisions;
std::vector<int64_t> Decisions;
@ -193,7 +194,8 @@ public:
DevelopmentModeMLInlineAdvisor(
Module &M, ModuleAnalysisManager &MAM,
std::unique_ptr<MLModelRunner> ModelRunner,
std::function<bool(CallBase &)> GetDefaultAdvice, bool IsDoingInference);
std::function<bool(CallBase &)> GetDefaultAdvice, bool IsDoingInference,
std::unique_ptr<TrainingLogger> Logger);
size_t getTotalSizeEstimate();
@ -211,11 +213,11 @@ public:
size_t getNativeSizeEstimate(const Function &F) const;
private:
bool isLogging() const { return !TrainingLog.empty(); }
bool isLogging() const { return !!Logger; }
std::function<bool(CallBase &)> GetDefaultAdvice;
TrainingLogger Logger;
const bool IsDoingInference;
std::unique_ptr<TrainingLogger> Logger;
const int32_t InitialNativeSize;
int32_t CurrentNativeSize = 0;
@ -346,7 +348,8 @@ private:
};
} // namespace
TrainingLogger::TrainingLogger() {
TrainingLogger::TrainingLogger(StringRef LogFileName)
: LogFileName(LogFileName) {
for (size_t I = 0; I < NumberOfFeatures; ++I) {
Features.push_back(InlineFeatures());
}
@ -364,7 +367,9 @@ void TrainingLogger::logInlineEvent(const InlineEvent &Event,
DefaultDecisions.push_back(Event.DefaultDecision);
}
void TrainingLogger::print(raw_fd_ostream &OutFile) {
void TrainingLogger::print() {
std::error_code EC;
raw_fd_ostream OutFile(LogFileName, EC);
size_t NumberOfRecords = Decisions.size();
if (NumberOfRecords == 0)
return;
@ -392,9 +397,11 @@ void TrainingLogger::print(raw_fd_ostream &OutFile) {
DevelopmentModeMLInlineAdvisor::DevelopmentModeMLInlineAdvisor(
Module &M, ModuleAnalysisManager &MAM,
std::unique_ptr<MLModelRunner> ModelRunner,
std::function<bool(CallBase &)> GetDefaultAdvice, bool IsDoingInference)
std::function<bool(CallBase &)> GetDefaultAdvice, bool IsDoingInference,
std::unique_ptr<TrainingLogger> Logger)
: MLInlineAdvisor(M, MAM, std::move(ModelRunner)),
GetDefaultAdvice(GetDefaultAdvice), IsDoingInference(IsDoingInference),
Logger(std::move(Logger)),
InitialNativeSize(isLogging() ? getTotalSizeEstimate() : 0),
CurrentNativeSize(InitialNativeSize) {
// We cannot have the case of neither inference nor logging.
@ -402,11 +409,8 @@ DevelopmentModeMLInlineAdvisor::DevelopmentModeMLInlineAdvisor(
}
DevelopmentModeMLInlineAdvisor::~DevelopmentModeMLInlineAdvisor() {
if (TrainingLog.empty())
return;
std::error_code ErrorCode;
raw_fd_ostream OutFile(TrainingLog, ErrorCode);
Logger.print(OutFile);
if (isLogging())
Logger->print();
}
size_t
@ -428,7 +432,7 @@ DevelopmentModeMLInlineAdvisor::getMandatoryAdvice(
return MLInlineAdvisor::getMandatoryAdvice(CB, ORE);
return std::make_unique<LoggingMLInlineAdvice>(
/*Advisor=*/this,
/*CB=*/CB, /*ORE=*/ORE, /*Recommendation=*/true, /*Logger=*/Logger,
/*CB=*/CB, /*ORE=*/ORE, /*Recommendation=*/true, /*Logger=*/*Logger,
/*CallerSizeEstimateBefore=*/getNativeSizeEstimate(*CB.getCaller()),
/*CalleeSizeEstimateBefore=*/
getNativeSizeEstimate(*CB.getCalledFunction()),
@ -446,7 +450,7 @@ DevelopmentModeMLInlineAdvisor::getAdviceFromModel(
return std::make_unique<LoggingMLInlineAdvice>(
/*Advisor=*/this,
/*CB=*/CB, /*ORE=*/ORE, /*Recommendation=*/Recommendation,
/*Logger=*/Logger,
/*Logger=*/*Logger,
/*CallerSizeEstimateBefore=*/getNativeSizeEstimate(*CB.getCaller()),
/*CalleeSizeEstimateBefore=*/
getNativeSizeEstimate(*CB.getCalledFunction()),
@ -531,7 +535,12 @@ std::unique_ptr<InlineAdvisor> llvm::getDevelopmentModeAdvisor(
}
IsDoingInference = true;
}
std::unique_ptr<TrainingLogger> Logger;
if (!TrainingLog.empty())
Logger = std::make_unique<TrainingLogger>(TrainingLog);
return std::make_unique<DevelopmentModeMLInlineAdvisor>(
M, MAM, std::move(Runner), GetDefaultAdvice, IsDoingInference);
M, MAM, std::move(Runner), GetDefaultAdvice, IsDoingInference,
std::move(Logger));
}
#endif // defined(LLVM_HAVE_TF_API)