llvm-project/mlir/tools/mlir-reduce/ReductionNode.cpp

131 lines
4.2 KiB
C++

//===- ReductionNode.cpp - Reduction Node Implementation -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the reduction nodes which are used to track of the
// metadata for a specific generated variant within a reduction pass and are the
// building blocks of the reduction tree structure. A reduction tree is used to
// keep track of the different generated variants throughout a reduction pass in
// the MLIR Reduce tool.
//
//===----------------------------------------------------------------------===//
#include "mlir/Reducer/ReductionNode.h"
using namespace mlir;
/// Sets up the metadata and links the node to its parent.
ReductionNode::ReductionNode(ModuleOp module, ReductionNode *parent)
: module(module), evaluated(false) {
if (parent != nullptr)
parent->linkVariant(this);
}
ReductionNode::ReductionNode(ModuleOp module, ReductionNode *parent,
std::vector<bool> transformSpace)
: module(module), evaluated(false), transformSpace(transformSpace) {
if (parent != nullptr)
parent->linkVariant(this);
}
/// Calculates and updates the size and interesting values of the module.
void ReductionNode::measureAndTest(const Tester &test) {
SmallString<128> filepath;
int fd;
// Print module to temprary file.
std::error_code ec =
llvm::sys::fs::createTemporaryFile("mlir-reduce", "mlir", fd, filepath);
if (ec)
llvm::report_fatal_error("Error making unique filename: " + ec.message());
llvm::ToolOutputFile out(filepath, fd);
module.print(out.os());
out.os().close();
if (out.os().has_error())
llvm::report_fatal_error("Error emitting bitcode to file '" + filepath);
size = out.os().tell();
interesting = test.isInteresting(filepath);
evaluated = true;
}
/// Returns true if the size and interestingness have been calculated.
bool ReductionNode::isEvaluated() const { return evaluated; }
/// Returns the size in bytes of the module.
int ReductionNode::getSize() const { return size; }
/// Returns true if the module exhibits the interesting behavior.
bool ReductionNode::isInteresting() const { return interesting; }
/// Returns the pointers to the child variants.
ReductionNode *ReductionNode::getVariant(unsigned long index) const {
if (index < variants.size())
return variants[index].get();
return nullptr;
}
/// Returns the number of child variants.
int ReductionNode::variantsSize() const { return variants.size(); }
/// Returns true if the child variants vector is empty.
bool ReductionNode::variantsEmpty() const { return variants.empty(); }
/// Link a child variant node.
void ReductionNode::linkVariant(ReductionNode *newVariant) {
std::unique_ptr<ReductionNode> ptrVariant(newVariant);
variants.push_back(std::move(ptrVariant));
}
/// Sort the child variants and remove the uninteresting ones.
void ReductionNode::organizeVariants(const Tester &test) {
// Ensure all variants are evaluated.
for (auto &var : variants)
if (!var->isEvaluated())
var->measureAndTest(test);
// Sort variants by interestingness and size.
llvm::array_pod_sort(
variants.begin(), variants.end(), [](const auto *lhs, const auto *rhs) {
if (lhs->get()->isInteresting() && !rhs->get()->isInteresting())
return 0;
if (!lhs->get()->isInteresting() && rhs->get()->isInteresting())
return 1;
return (lhs->get()->getSize(), rhs->get()->getSize());
});
int interestingCount = 0;
for (auto &var : variants) {
if (var->isInteresting()) {
++interestingCount;
} else {
break;
}
}
// Remove uninteresting variants.
variants.resize(interestingCount);
}
/// Returns the number of non transformed indices.
int ReductionNode::transformSpaceSize() {
return std::count(transformSpace.begin(), transformSpace.end(), false);
}
/// Returns a vector of the transformable indices in the Module.
const std::vector<bool> ReductionNode::getTransformSpace() {
return transformSpace;
}