2020-08-21 12:59:23 +08:00
|
|
|
//===- ReductionTreeUtils.cpp - Reduction Tree Utilities ------------------===//
|
|
|
|
//
|
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
//
|
|
|
|
// This file defines the Reduction Tree Utilities. It defines pass independent
|
|
|
|
// methods that help in a reduction pass of the MLIR Reduce tool.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#include "mlir/Reducer/ReductionTreeUtils.h"
|
|
|
|
|
|
|
|
#define DEBUG_TYPE "mlir-reduce"
|
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
|
|
|
|
/// Update the golden module's content with that of the reduced module.
|
|
|
|
void ReductionTreeUtils::updateGoldenModule(ModuleOp &golden,
|
|
|
|
ModuleOp reduced) {
|
|
|
|
golden.getBody()->clear();
|
|
|
|
|
|
|
|
golden.getBody()->getOperations().splice(golden.getBody()->begin(),
|
|
|
|
reduced.getBody()->getOperations());
|
|
|
|
}
|
|
|
|
|
2020-10-29 03:03:15 +08:00
|
|
|
/// Update the smallest node traversed so far in the reduction tree and
|
2020-08-21 12:59:23 +08:00
|
|
|
/// print the debugging information for the currNode being traversed.
|
|
|
|
void ReductionTreeUtils::updateSmallestNode(ReductionNode *currNode,
|
|
|
|
ReductionNode *&smallestNode,
|
|
|
|
std::vector<int> path) {
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "\nTree Path: root");
|
2020-08-25 08:58:17 +08:00
|
|
|
#ifndef NDEBUG
|
2020-08-21 12:59:23 +08:00
|
|
|
for (int nodeIndex : path)
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << " -> " << nodeIndex);
|
2020-08-25 08:58:17 +08:00
|
|
|
#endif
|
2020-08-21 12:59:23 +08:00
|
|
|
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "\nSize (chars): " << currNode->getSize());
|
|
|
|
if (currNode->getSize() < smallestNode->getSize()) {
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << " - new smallest node!");
|
|
|
|
smallestNode = currNode;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Create a transform space index vector based on the specified number of
|
|
|
|
/// indices.
|
|
|
|
std::vector<bool> ReductionTreeUtils::createTransformSpace(ModuleOp module,
|
|
|
|
int numIndices) {
|
|
|
|
std::vector<bool> transformSpace;
|
|
|
|
for (int i = 0; i < numIndices; ++i)
|
|
|
|
transformSpace.push_back(false);
|
|
|
|
|
|
|
|
return transformSpace;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Translate a section of the untransformed indices into index ranges of the
/// transform space.
///
/// `tSpace` holds one flag per index; a false entry is still untransformed.
/// `start` and `end` delimit a half-open section [start, end) counted over the
/// false entries only. The result lists the half-open ranges
/// [rangeStart, rangeEnd) of positions in `tSpace` that cover exactly those
/// false entries; a true entry inside the section splits it into several
/// ranges.
///
/// If `tSpace` runs out before `end` false entries have been seen, the range
/// that is still open is closed at `tSpace.size()` instead of being dropped.
/// An empty section (start >= end) or an empty `tSpace` yields no ranges.
static std::vector<std::tuple<int, int>>
getRanges(const std::vector<bool> &tSpace, int start, int end) {
  std::vector<std::tuple<int, int>> ranges;
  int rangeStart = 0;
  bool inside = false;
  // Number of false (untransformed) entries seen before the current position.
  int transformableCount = 0;
  const int numIndices = static_cast<int>(tSpace.size());

  for (int index = 0; index < numIndices; ++index) {
    const bool transformed = tSpace[index];

    if (start <= transformableCount && transformableCount < end) {
      if (!transformed && !inside) {
        // First untransformed entry of a new range within the section.
        inside = true;
        rangeStart = index;
      } else if (transformed && inside) {
        // An already transformed entry interrupts the current range.
        ranges.emplace_back(rangeStart, index);
        inside = false;
      }
    }

    if (!transformed)
      ++transformableCount;

    // The section is complete once `end` untransformed entries were counted.
    if (transformableCount == end && inside) {
      ranges.emplace_back(rangeStart, index + 1);
      inside = false;
      break;
    }
  }

  // The input ended while a range was still open (fewer than `end`
  // untransformed entries exist): close it at the end of the transform space.
  if (inside)
    ranges.emplace_back(rangeStart, numIndices);

  return ranges;
}
|
|
|
|
|
|
|
|
/// Create the specified number of variants by applying the transform method
|
2020-10-29 03:03:15 +08:00
|
|
|
/// to different ranges of indices in the parent module. The isDeletion boolean
|
2020-08-21 12:59:23 +08:00
|
|
|
/// specifies if the transformation is the deletion of indices.
|
|
|
|
void ReductionTreeUtils::createVariants(
|
|
|
|
ReductionNode *parent, const Tester &test, int numVariants,
|
|
|
|
llvm::function_ref<void(ModuleOp, int, int)> transform, bool isDeletion) {
|
|
|
|
std::vector<bool> newTSpace;
|
|
|
|
ModuleOp module = parent->getModule();
|
|
|
|
|
|
|
|
std::vector<bool> parentTSpace = parent->getTransformSpace();
|
|
|
|
int indexCount = parent->transformSpaceSize();
|
|
|
|
std::vector<std::tuple<int, int>> ranges;
|
|
|
|
|
|
|
|
// No new variants can be created.
|
|
|
|
if (indexCount == 0)
|
|
|
|
return;
|
|
|
|
|
|
|
|
// Create a single variant by transforming the unique index.
|
|
|
|
if (indexCount == 1) {
|
|
|
|
ModuleOp variantModule = module.clone();
|
|
|
|
if (isDeletion) {
|
|
|
|
transform(variantModule, 0, 1);
|
|
|
|
} else {
|
|
|
|
ranges = getRanges(parentTSpace, 0, parentTSpace.size());
|
|
|
|
transform(variantModule, std::get<0>(ranges[0]), std::get<1>(ranges[0]));
|
|
|
|
}
|
|
|
|
|
|
|
|
new ReductionNode(variantModule, parent, newTSpace);
|
|
|
|
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Create the specified number of variants.
|
|
|
|
for (int i = 0; i < numVariants; ++i) {
|
|
|
|
ModuleOp variantModule = module.clone();
|
|
|
|
newTSpace = parent->getTransformSpace();
|
|
|
|
int sectionSize = indexCount / numVariants;
|
|
|
|
int sectionStart = sectionSize * i;
|
|
|
|
int sectionEnd = sectionSize * (i + 1);
|
|
|
|
|
|
|
|
if (i == numVariants - 1)
|
|
|
|
sectionEnd = indexCount;
|
|
|
|
|
|
|
|
if (isDeletion)
|
|
|
|
transform(variantModule, sectionStart, sectionEnd);
|
|
|
|
|
|
|
|
ranges = getRanges(parentTSpace, sectionStart, sectionEnd);
|
|
|
|
|
|
|
|
for (auto range : ranges) {
|
|
|
|
int rangeStart = std::get<0>(range);
|
|
|
|
int rangeEnd = std::get<1>(range);
|
|
|
|
|
|
|
|
for (int x = rangeStart; x < rangeEnd; ++x)
|
|
|
|
newTSpace[x] = true;
|
|
|
|
|
|
|
|
if (!isDeletion)
|
|
|
|
transform(variantModule, rangeStart, rangeEnd);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Create Reduction Node in the Reduction tree
|
|
|
|
new ReductionNode(variantModule, parent, newTSpace);
|
|
|
|
}
|
|
|
|
}
|