[mlir][linalg] lower index operations during linalg to vector lowering.

The patch extends the vectorization pass to lower linalg index operations to vector code. It allocates constant 1d vectors that enumerate the indexes along the iteration dimensions and broadcasts/transposes these 1d vectors to the iteration space.

Differential Revision: https://reviews.llvm.org/D100373
This commit is contained in:
Tobias Gysi 2021-04-20 11:26:44 +00:00
parent e156f2515c
commit b9715156ff
7 changed files with 134 additions and 27 deletions

View File

@ -1242,11 +1242,21 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/// appear in the operands.
SmallVector<Value, 4> createFlatListOfOperandDims(OpBuilder &, Location);
/// Return the flat list of all operands' static dimension sizes in the
/// order they appear in the operands. All operand dimension sizes have to
/// be statically known.
SmallVector<int64_t, 4> createFlatListOfOperandStaticDims();
/// Create the loop ranges to materialize the computation over the current
/// operands. This is done by applying `getShapesToLoopsMap` to
/// `createFlatListOfOperandDims`.
SmallVector<Range, 4> createLoopRanges(OpBuilder &b, Location loc);
/// Compute the static loop sizes necessary to vectorize the computation.
/// This is done by applying `getShapesToLoopsMap` to
/// `createFlatListOfOperandStaticDims`.
SmallVector<int64_t, 4> computeStaticLoopSizes();
/// Returns all the operands past the inputs, output_buffers and
/// init_tensors operands. Asserts that these operands are value types to
/// allow transformations like tiling to just use the values when cloning

View File

@ -124,6 +124,7 @@ public:
DenseIntElementsAttr getBoolVectorAttr(ArrayRef<bool> values);
DenseIntElementsAttr getI32VectorAttr(ArrayRef<int32_t> values);
DenseIntElementsAttr getI64VectorAttr(ArrayRef<int64_t> values);
DenseIntElementsAttr getIndexVectorAttr(ArrayRef<int64_t> values);
/// Tensor-typed DenseIntElementsAttr getters. `values` can be empty.
/// These are generally preferable for representing general lists of integers

View File

@ -193,6 +193,16 @@ SmallVector<Value, 4> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
return res;
}
SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
SmallVector<int64_t, 4> res;
for (Value v : getShapedOperands()) {
ShapedType t = v.getType().template cast<ShapedType>();
assert(t.hasStaticShape() && "expected operands to have static shapes");
llvm::append_range(res, t.getShape());
}
return res;
}
SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
AffineMap map = getLoopsToShapesMap();
unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
@ -211,6 +221,19 @@ SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
return res;
}
SmallVector<int64_t, 4> LinalgOp::computeStaticLoopSizes() {
AffineMap map = getLoopsToShapesMap();
unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
SmallVector<int64_t, 4> allShapeSizes = createFlatListOfOperandStaticDims();
SmallVector<int64_t, 4> res(numDims, 0);
for (unsigned idx = 0; idx < numRes; ++idx) {
auto result = map.getResult(idx);
if (auto d = result.dyn_cast<AffineDimExpr>())
res[d.getPosition()] = allShapeSizes[idx];
}
return res;
}
/// Visitor to check if any of the given set of positions from AffineDimExprs
/// are used within an AffineExpr.
struct HasAffineDimExprVisitor

View File

@ -462,8 +462,7 @@ mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
Operation *op, PatternRewriter &rewriter) const {
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
// TODO: remove hasIndexSemantics check once index ops are supported.
if (!linalgOp || linalgOp.hasIndexSemantics())
if (!linalgOp)
return failure();
if (failed(filter.checkAndNotify(rewriter, linalgOp)))
return failure();

View File

@ -166,6 +166,42 @@ vectorizeLinalgYield(OpBuilder &builder, Operation *op,
return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
}
/// Helper function to vectorize the index operations of a `linalgOp`. Return
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
/// should map the produced operations. This function is meant to be used as a
/// CustomVectorizationHook.
static VectorizationResult
vectorizeLinalgIndex(OpBuilder &builder, Operation *op, LinalgOp linalgOp) {
IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
if (!indexOp)
return VectorizationResult{VectorizationStatus::Failure, nullptr};
auto loc = indexOp.getLoc();
// Compute the static loop sizes of the index op.
auto targetShape = linalgOp.computeStaticLoopSizes();
// Compute a one-dimensional index vector for the index op dimension.
SmallVector<int64_t> constantSeq(
llvm::seq<int64_t>(0, targetShape[indexOp.dim()]));
ConstantOp constantOp =
builder.create<ConstantOp>(loc, builder.getIndexVectorAttr(constantSeq));
// Return the one-dimensional index vector if it lives in the trailing
// dimension of the iteration space since the vectorization algorithm in this
// case can handle the broadcast.
if (indexOp.dim() == targetShape.size() - 1)
return VectorizationResult{VectorizationStatus::NewOp, constantOp};
// Otherwise permute the targetShape to move the index dimension last,
// broadcast the one-dimensional index vector to the permuted shape, and
// finally transpose the broadcasted index vector to undo the permutation.
std::swap(targetShape[indexOp.dim()], targetShape.back());
auto broadCastOp = builder.create<vector::BroadcastOp>(
loc, VectorType::get(targetShape, builder.getIndexType()), constantOp);
SmallVector<int64_t> transposition(
llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
std::swap(transposition.back(), transposition[indexOp.dim()]);
auto transposeOp =
builder.create<vector::TransposeOp>(loc, broadCastOp, transposition);
return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
}
/// Generic vectorization for a single operation `op`, given already vectorized
/// operands carried by `bvm`. Vectorization occurs as follows:
/// 1. Try to apply any of the `customVectorizationHooks` and return its
@ -245,7 +281,7 @@ static bool hasOnlyScalarElementwiseOp(Region &r) {
if (!llvm::hasSingleElement(r))
return false;
for (Operation &op : r.front()) {
if (!(isa<ConstantOp, linalg::YieldOp>(op) ||
if (!(isa<ConstantOp, linalg::YieldOp, linalg::IndexOp>(op) ||
OpTrait::hasElementwiseMappableTraits(&op)) ||
llvm::any_of(op.getResultTypes(),
[](Type type) { return !type.isIntOrIndexOrFloat(); }))
@ -293,7 +329,9 @@ static AffineMap getTransferReadMap(LinalgOp linalgOp, unsigned argIndex) {
/// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d
/// load).
/// TODO: Reuse opportunities for RAR dependencies.
/// 4. Register CustomVectorizationHook for YieldOp to capture the results.
/// 4a. Register CustomVectorizationHook for YieldOp to capture the results.
/// 4b. Register CustomVectorizationHook for IndexOp to access the iteration
/// indices.
/// 5. Iteratively call vectorizeOneOp on the region operations.
LogicalResult vectorizeAsLinalgGeneric(
OpBuilder &builder, LinalgOp linalgOp, SmallVectorImpl<Value> &newResults,
@ -333,16 +371,23 @@ LogicalResult vectorizeAsLinalgGeneric(
bvm.map(vectorArg, vectorRead);
}
// 4. Register CustomVectorizationHook for yieldOp.
auto hooks = llvm::to_vector<4>(customVectorizationHooks);
// 4a. Register CustomVectorizationHook for yieldOp.
CustomVectorizationHook vectorizeYield =
[&](Operation *op,
const BlockAndValueMapping &bvm) -> VectorizationResult {
return vectorizeLinalgYield(builder, op, bvm, linalgOp, newResults);
};
// Append the vectorizeYield hook.
auto hooks = llvm::to_vector<4>(customVectorizationHooks);
hooks.push_back(vectorizeYield);
// 4b. Register CustomVectorizationHook for indexOp.
CustomVectorizationHook vectorizeIndex =
[&](Operation *op,
const BlockAndValueMapping &bvm) -> VectorizationResult {
return vectorizeLinalgIndex(builder, op, linalgOp);
};
hooks.push_back(vectorizeIndex);
// 5. Iteratively call `vectorizeOneOp` to each op in the slice.
for (Operation &op : block.getOperations()) {
VectorizationResult result = vectorizeOneOp(builder, &op, bvm, hooks);
@ -401,9 +446,6 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
for (Type outputTensorType : linalgOp.getOutputTensorTypes())
if (!outputTensorType.cast<ShapedType>().hasStaticShape())
return failure();
// TODO: remove once index ops are supported.
if (linalgOp.hasIndexSemantics())
return failure();
if (isElementwise(op))
return success();
return success(isaContractionOpInterface(linalgOp));

View File

@ -120,6 +120,12 @@ DenseIntElementsAttr Builder::getI64VectorAttr(ArrayRef<int64_t> values) {
values);
}
DenseIntElementsAttr Builder::getIndexVectorAttr(ArrayRef<int64_t> values) {
return DenseIntElementsAttr::get(
VectorType::get(static_cast<int64_t>(values.size()), getIndexType()),
values);
}
DenseIntElementsAttr Builder::getI32TensorAttr(ArrayRef<int32_t> values) {
return DenseIntElementsAttr::get(
RankedTensorType::get(static_cast<int64_t>(values.size()),

View File

@ -174,6 +174,49 @@ func @test_vectorize_copy_scalar(%A : memref<f32>, %B : memref<f32>) {
// -----
// CHECK-LABEL: func @test_vectorize_trailing_index
// CHECK-SAME: (%[[ARG0:.*]]: memref<1x2x4x8xindex>)
func @test_vectorize_trailing_index(%arg0: memref<1x2x4x8xindex>) {
// CHECK-DAG: %[[CST0:.*]] = constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
outs(%arg0: memref<1x2x4x8xindex>) {
^bb0(%arg1: index):
// CHECK: %[[BCST:.*]] = vector.broadcast %[[CST0]] : vector<8xindex> to vector<1x2x4x8xindex>
// CHECK: vector.transfer_write %[[BCST]], %[[ARG0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {{.*}} : vector<1x2x4x8xindex>, memref<1x2x4x8xindex>
%0 = linalg.index 3 : index
linalg.yield %0 : index
}
return
}
// -----
// CHECK-LABEL: func @test_vectorize_inner_index
// CHECK-SAME: (%[[ARG0:.*]]: memref<1x2x4x8xindex>)
func @test_vectorize_inner_index(%arg0: memref<1x2x4x8xindex>) {
// CHECK-DAG: %[[CST0:.*]] = constant dense<[0, 1]> : vector<2xindex>
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
outs(%arg0: memref<1x2x4x8xindex>) {
^bb0(%arg1: index):
// CHECK: %[[BCST:.*]] = vector.broadcast %[[CST0]] : vector<2xindex> to vector<1x8x4x2xindex>
// CHECK: %[[TRAN:.*]] = vector.transpose %[[BCST]], [0, 3, 2, 1] : vector<1x8x4x2xindex> to vector<1x2x4x8xindex>
// CHECK: vector.transfer_write %[[TRAN]], %[[ARG0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {{.*}} : vector<1x2x4x8xindex>, memref<1x2x4x8xindex>
%0 = linalg.index 1 : index
linalg.yield %0 : index
}
return
}
// -----
// CHECK-LABEL: func @generic_vectorize
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x256xf32>, %[[ARG1:.*]]: memref<4x256xf32>,
// CHECK-SAME: %[[ARG2:.*]]: memref<256xf32>, %[[ARG3:.*]]: f32)
@ -252,7 +295,6 @@ func @generic_vectorize(%arg0: memref<4x256xf32>,
return
}
// -----
// CHECK-LABEL: func @generic_vectorize_tensor
@ -469,19 +511,3 @@ func @pad_dynamic(%arg0: tensor<1x2x2x?xf32>, %low: index, %high: index,
} : tensor<1x2x2x?xf32> to tensor<6x?x?x?xf32>
return %0 : tensor<6x?x?x?xf32>
}
// -----
// CHECK-LABEL: @index_op
// CHECK: linalg.generic
func @index_op(%arg0: memref<4x8xindex>) {
linalg.generic {
indexing_maps = [affine_map<(i, j) -> (i, j)>],
iterator_types = ["parallel", "parallel"]}
outs(%arg0 : memref<4x8xindex>) {
^bb0(%arg1: index): // no predecessors
%0 = linalg.index 1 : index
linalg.yield %0 : index
}
return
}