Make createMaterializeVectorsPass take a vectorSize parameter - NFC

This CL allows the programmatic control of the target hardware vector size when creating a MaterializeVectorsPass.
This is useful for registering passes for the tutorial.

PiperOrigin-RevId: 240996136
This commit is contained in:
Nicolas Vasilache 2019-03-29 09:47:30 -07:00 committed by jpienaar
parent 5303587448
commit f93a5be65f
3 changed files with 24 additions and 14 deletions

View File

@ -52,7 +52,8 @@ createVectorizePass(llvm::ArrayRef<int64_t> virtualVectorSize);
FunctionPassBase *createVectorizerTestPass();
/// Creates a pass to lower super-vectors to target-dependent HW vectors.
FunctionPassBase *createMaterializeVectorsPass();
FunctionPassBase *
createMaterializeVectorsPass(llvm::ArrayRef<int64_t> vectorSize);
/// Creates a loop unrolling pass with the provided parameters.
/// 'getUnrollFactor' is a function callback for clients to supply a function

View File

@ -185,9 +185,8 @@ struct MaterializationState {
/// of the type and we assert everything is f32.
/// TODO(ntv): relax the assumptions on admissible element type once a
/// contract exists.
MaterializationState() : hwVectorSize(clVectorSize.size(), 0) {
std::copy(clVectorSize.begin(), clVectorSize.end(), hwVectorSize.begin());
}
MaterializationState(SmallVector<int64_t, 8> sizes) : hwVectorSize(sizes) {}
SmallVector<int64_t, 8> hwVectorSize;
VectorType superVectorType;
VectorType hwVectorType;
@ -195,7 +194,18 @@ struct MaterializationState {
DenseMap<Value *, Value *> *substitutionsMap;
};
/// Base state for the vector materialization pass.
/// Command line arguments are preempted by non-empty pass arguments.
struct MaterializeVectorsPass : public FunctionPass<MaterializeVectorsPass> {
MaterializeVectorsPass()
: hwVectorSize(clVectorSize.begin(), clVectorSize.end()) {}
MaterializeVectorsPass(ArrayRef<int64_t> hwVectorSize)
: MaterializeVectorsPass() {
if (!hwVectorSize.empty())
this->hwVectorSize.assign(hwVectorSize.begin(), hwVectorSize.end());
}
SmallVector<int64_t, 8> hwVectorSize;
void runOnFunction() override;
};
@ -739,11 +749,11 @@ void MaterializeVectorsPass::runOnFunction() {
LLVM_DEBUG(dbgs() << "\nMaterializeVectors on Function\n");
LLVM_DEBUG(f->print(dbgs()));
MaterializationState state;
MaterializationState state(hwVectorSize);
// Get the hardware vector type.
// TODO(ntv): get elemental type from super-vector type rather than force f32.
auto subVectorType =
VectorType::get(state.hwVectorSize, FloatType::getF32(&getContext()));
VectorType::get(hwVectorSize, FloatType::getF32(&getContext()));
// Capture terminators; i.e. vector_transfer_write ops involving a strict
// super-vector of subVectorType.
@ -765,8 +775,9 @@ void MaterializeVectorsPass::runOnFunction() {
signalPassFailure();
}
FunctionPassBase *mlir::createMaterializeVectorsPass() {
return new MaterializeVectorsPass();
FunctionPassBase *
mlir::createMaterializeVectorsPass(llvm::ArrayRef<int64_t> vectorSize) {
return new MaterializeVectorsPass(vectorSize);
}
static PassRegistration<MaterializeVectorsPass>

View File

@ -620,12 +620,10 @@ struct Vectorize : public FunctionPass<Vectorize> {
} // end anonymous namespace
Vectorize::Vectorize() {
this->vectorSizes.assign(clVirtualVectorSize.begin(),
clVirtualVectorSize.end());
this->fastestVaryingPattern.assign(clFastestVaryingPattern.begin(),
clFastestVaryingPattern.end());
}
Vectorize::Vectorize()
: vectorSizes(clVirtualVectorSize.begin(), clVirtualVectorSize.end()),
fastestVaryingPattern(clFastestVaryingPattern.begin(),
clFastestVaryingPattern.end()) {}
Vectorize::Vectorize(ArrayRef<int64_t> virtualVectorSize) : Vectorize() {
if (!virtualVectorSize.empty()) {