forked from OSchip/llvm-project
[mlir] Update Toy tutorial to use callback-based loop constructors
We recently introduced support for building loops or loop nests using callbacks that populate the body. Use these in the tutorial instead of setInsertionPoint manipulations. Differential Revision: https://reviews.llvm.org/D82104
This commit is contained in:
parent
8647a9bc51
commit
68628c94cd
|
@ -50,15 +50,14 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
|
|||
}
|
||||
|
||||
/// This defines the function type used to process an iteration of a lowered
|
||||
/// loop. It takes as input a rewriter, an array of memRefOperands corresponding
|
||||
/// to the operands of the input operation, and the set of loop induction
|
||||
/// variables for the iteration. It returns a value to store at the current
|
||||
/// index of the iteration.
|
||||
using LoopIterationFn = function_ref<Value(PatternRewriter &rewriter,
|
||||
ArrayRef<Value> memRefOperands,
|
||||
ArrayRef<Value> loopIvs)>;
|
||||
/// loop. It takes as input an OpBuilder, an range of memRefOperands
|
||||
/// corresponding to the operands of the input operation, and the range of loop
|
||||
/// induction variables for the iteration. It returns a value to store at the
|
||||
/// current index of the iteration.
|
||||
using LoopIterationFn = function_ref<Value(
|
||||
OpBuilder &rewriter, ValueRange memRefOperands, ValueRange loopIvs)>;
|
||||
|
||||
static void lowerOpToLoops(Operation *op, ArrayRef<Value> operands,
|
||||
static void lowerOpToLoops(Operation *op, ValueRange operands,
|
||||
PatternRewriter &rewriter,
|
||||
LoopIterationFn processIteration) {
|
||||
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
|
||||
|
@ -68,22 +67,21 @@ static void lowerOpToLoops(Operation *op, ArrayRef<Value> operands,
|
|||
auto memRefType = convertTensorToMemRef(tensorType);
|
||||
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
|
||||
|
||||
// Create an empty affine loop for each of the dimensions within the shape.
|
||||
SmallVector<Value, 4> loopIvs;
|
||||
for (auto dim : tensorType.getShape()) {
|
||||
auto loop = rewriter.create<AffineForOp>(loc, /*lb=*/0, dim, /*step=*/1);
|
||||
loopIvs.push_back(loop.getInductionVar());
|
||||
|
||||
// Update the rewriter insertion point to the beginning of the loop.
|
||||
rewriter.setInsertionPointToStart(loop.getBody());
|
||||
}
|
||||
|
||||
// Generate a call to the processing function with the rewriter, the memref
|
||||
// operands, and the loop induction variables. This function will return the
|
||||
// value to store at the current index.
|
||||
Value valueToStore = processIteration(rewriter, operands, loopIvs);
|
||||
rewriter.create<AffineStoreOp>(loc, valueToStore, alloc,
|
||||
llvm::makeArrayRef(loopIvs));
|
||||
// Create a nest of affine loops, with one loop per dimension of the shape.
|
||||
// The buildAffineLoopNest function takes a callback that is used to construct
|
||||
// the body of the innermost loop given a builder, a location and a range of
|
||||
// loop induction variables.
|
||||
SmallVector<int64_t, 4> lowerBounds(tensorType.getRank(), /*Value=*/0);
|
||||
SmallVector<int64_t, 4> steps(tensorType.getRank(), /*Value=*/1);
|
||||
buildAffineLoopNest(
|
||||
rewriter, loc, lowerBounds, tensorType.getShape(), steps,
|
||||
[&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) {
|
||||
// Call the processing function with the rewriter, the memref operands,
|
||||
// and the loop induction variables. This function will return the value
|
||||
// to store at the current index.
|
||||
Value valueToStore = processIteration(nestedBuilder, operands, ivs);
|
||||
nestedBuilder.create<AffineStoreOp>(loc, valueToStore, alloc, ivs);
|
||||
});
|
||||
|
||||
// Replace this operation with the generated alloc.
|
||||
rewriter.replaceOp(op, alloc);
|
||||
|
@ -105,8 +103,8 @@ struct BinaryOpLowering : public ConversionPattern {
|
|||
auto loc = op->getLoc();
|
||||
lowerOpToLoops(
|
||||
op, operands, rewriter,
|
||||
[loc](PatternRewriter &rewriter, ArrayRef<Value> memRefOperands,
|
||||
ArrayRef<Value> loopIvs) {
|
||||
[loc](OpBuilder &builder, ValueRange memRefOperands,
|
||||
ValueRange loopIvs) {
|
||||
// Generate an adaptor for the remapped operands of the BinaryOp. This
|
||||
// allows for using the nice named accessors that are generated by the
|
||||
// ODS.
|
||||
|
@ -115,12 +113,12 @@ struct BinaryOpLowering : public ConversionPattern {
|
|||
// Generate loads for the element of 'lhs' and 'rhs' at the inner
|
||||
// loop.
|
||||
auto loadedLhs =
|
||||
rewriter.create<AffineLoadOp>(loc, binaryAdaptor.lhs(), loopIvs);
|
||||
builder.create<AffineLoadOp>(loc, binaryAdaptor.lhs(), loopIvs);
|
||||
auto loadedRhs =
|
||||
rewriter.create<AffineLoadOp>(loc, binaryAdaptor.rhs(), loopIvs);
|
||||
builder.create<AffineLoadOp>(loc, binaryAdaptor.rhs(), loopIvs);
|
||||
|
||||
// Create the binary operation performed on the loaded values.
|
||||
return rewriter.create<LoweredBinaryOp>(loc, loadedLhs, loadedRhs);
|
||||
return builder.create<LoweredBinaryOp>(loc, loadedLhs, loadedRhs);
|
||||
});
|
||||
return success();
|
||||
}
|
||||
|
@ -227,21 +225,21 @@ struct TransposeOpLowering : public ConversionPattern {
|
|||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto loc = op->getLoc();
|
||||
lowerOpToLoops(
|
||||
op, operands, rewriter,
|
||||
[loc](PatternRewriter &rewriter, ArrayRef<Value> memRefOperands,
|
||||
ArrayRef<Value> loopIvs) {
|
||||
// Generate an adaptor for the remapped operands of the TransposeOp.
|
||||
// This allows for using the nice named accessors that are generated
|
||||
// by the ODS.
|
||||
toy::TransposeOpAdaptor transposeAdaptor(memRefOperands);
|
||||
Value input = transposeAdaptor.input();
|
||||
lowerOpToLoops(op, operands, rewriter,
|
||||
[loc](OpBuilder &builder, ValueRange memRefOperands,
|
||||
ValueRange loopIvs) {
|
||||
// Generate an adaptor for the remapped operands of the
|
||||
// TransposeOp. This allows for using the nice named
|
||||
// accessors that are generated by the ODS.
|
||||
toy::TransposeOpAdaptor transposeAdaptor(memRefOperands);
|
||||
Value input = transposeAdaptor.input();
|
||||
|
||||
// Transpose the elements by generating a load from the reverse
|
||||
// indices.
|
||||
SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
|
||||
return rewriter.create<AffineLoadOp>(loc, input, reverseIvs);
|
||||
});
|
||||
// Transpose the elements by generating a load from the
|
||||
// reverse indices.
|
||||
SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
|
||||
return builder.create<AffineLoadOp>(loc, input,
|
||||
reverseIvs);
|
||||
});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -50,13 +50,12 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
|
|||
}
|
||||
|
||||
/// This defines the function type used to process an iteration of a lowered
|
||||
/// loop. It takes as input a rewriter, an array of memRefOperands corresponding
|
||||
/// to the operands of the input operation, and the set of loop induction
|
||||
/// variables for the iteration. It returns a value to store at the current
|
||||
/// index of the iteration.
|
||||
using LoopIterationFn = function_ref<Value(PatternRewriter &rewriter,
|
||||
ArrayRef<Value> memRefOperands,
|
||||
ArrayRef<Value> loopIvs)>;
|
||||
/// loop. It takes as input an OpBuilder, an range of memRefOperands
|
||||
/// corresponding to the operands of the input operation, and the range of loop
|
||||
/// induction variables for the iteration. It returns a value to store at the
|
||||
/// current index of the iteration.
|
||||
using LoopIterationFn = function_ref<Value(
|
||||
OpBuilder &builder, ValueRange memRefOperands, ValueRange loopIvs)>;
|
||||
|
||||
static void lowerOpToLoops(Operation *op, ArrayRef<Value> operands,
|
||||
PatternRewriter &rewriter,
|
||||
|
@ -68,22 +67,21 @@ static void lowerOpToLoops(Operation *op, ArrayRef<Value> operands,
|
|||
auto memRefType = convertTensorToMemRef(tensorType);
|
||||
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
|
||||
|
||||
// Create an empty affine loop for each of the dimensions within the shape.
|
||||
SmallVector<Value, 4> loopIvs;
|
||||
for (auto dim : tensorType.getShape()) {
|
||||
auto loop = rewriter.create<AffineForOp>(loc, /*lb=*/0, dim, /*step=*/1);
|
||||
loopIvs.push_back(loop.getInductionVar());
|
||||
|
||||
// Update the rewriter insertion point to the beginning of the loop.
|
||||
rewriter.setInsertionPointToStart(loop.getBody());
|
||||
}
|
||||
|
||||
// Generate a call to the processing function with the rewriter, the memref
|
||||
// operands, and the loop induction variables. This function will return the
|
||||
// value to store at the current index.
|
||||
Value valueToStore = processIteration(rewriter, operands, loopIvs);
|
||||
rewriter.create<AffineStoreOp>(loc, valueToStore, alloc,
|
||||
llvm::makeArrayRef(loopIvs));
|
||||
// Create a nest of affine loops, with one loop per dimension of the shape.
|
||||
// The buildAffineLoopNest function takes a callback that is used to construct
|
||||
// the body of the innermost loop given a builder, a location and a range of
|
||||
// loop induction variables.
|
||||
SmallVector<int64_t, 4> lowerBounds(tensorType.getRank(), /*Value=*/0);
|
||||
SmallVector<int64_t, 4> steps(tensorType.getRank(), /*Value=*/1);
|
||||
buildAffineLoopNest(
|
||||
rewriter, loc, lowerBounds, tensorType.getShape(), steps,
|
||||
[&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) {
|
||||
// Call the processing function with the rewriter, the memref operands,
|
||||
// and the loop induction variables. This function will return the value
|
||||
// to store at the current index.
|
||||
Value valueToStore = processIteration(nestedBuilder, operands, ivs);
|
||||
nestedBuilder.create<AffineStoreOp>(loc, valueToStore, alloc, ivs);
|
||||
});
|
||||
|
||||
// Replace this operation with the generated alloc.
|
||||
rewriter.replaceOp(op, alloc);
|
||||
|
@ -105,8 +103,8 @@ struct BinaryOpLowering : public ConversionPattern {
|
|||
auto loc = op->getLoc();
|
||||
lowerOpToLoops(
|
||||
op, operands, rewriter,
|
||||
[loc](PatternRewriter &rewriter, ArrayRef<Value> memRefOperands,
|
||||
ArrayRef<Value> loopIvs) {
|
||||
[loc](OpBuilder &builder, ValueRange memRefOperands,
|
||||
ValueRange loopIvs) {
|
||||
// Generate an adaptor for the remapped operands of the BinaryOp. This
|
||||
// allows for using the nice named accessors that are generated by the
|
||||
// ODS.
|
||||
|
@ -115,12 +113,12 @@ struct BinaryOpLowering : public ConversionPattern {
|
|||
// Generate loads for the element of 'lhs' and 'rhs' at the inner
|
||||
// loop.
|
||||
auto loadedLhs =
|
||||
rewriter.create<AffineLoadOp>(loc, binaryAdaptor.lhs(), loopIvs);
|
||||
builder.create<AffineLoadOp>(loc, binaryAdaptor.lhs(), loopIvs);
|
||||
auto loadedRhs =
|
||||
rewriter.create<AffineLoadOp>(loc, binaryAdaptor.rhs(), loopIvs);
|
||||
builder.create<AffineLoadOp>(loc, binaryAdaptor.rhs(), loopIvs);
|
||||
|
||||
// Create the binary operation performed on the loaded values.
|
||||
return rewriter.create<LoweredBinaryOp>(loc, loadedLhs, loadedRhs);
|
||||
return builder.create<LoweredBinaryOp>(loc, loadedLhs, loadedRhs);
|
||||
});
|
||||
return success();
|
||||
}
|
||||
|
@ -226,21 +224,21 @@ struct TransposeOpLowering : public ConversionPattern {
|
|||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto loc = op->getLoc();
|
||||
lowerOpToLoops(
|
||||
op, operands, rewriter,
|
||||
[loc](PatternRewriter &rewriter, ArrayRef<Value> memRefOperands,
|
||||
ArrayRef<Value> loopIvs) {
|
||||
// Generate an adaptor for the remapped operands of the TransposeOp.
|
||||
// This allows for using the nice named accessors that are generated
|
||||
// by the ODS.
|
||||
toy::TransposeOpAdaptor transposeAdaptor(memRefOperands);
|
||||
Value input = transposeAdaptor.input();
|
||||
lowerOpToLoops(op, operands, rewriter,
|
||||
[loc](OpBuilder &builder, ValueRange memRefOperands,
|
||||
ValueRange loopIvs) {
|
||||
// Generate an adaptor for the remapped operands of the
|
||||
// TransposeOp. This allows for using the nice named
|
||||
// accessors that are generated by the ODS.
|
||||
toy::TransposeOpAdaptor transposeAdaptor(memRefOperands);
|
||||
Value input = transposeAdaptor.input();
|
||||
|
||||
// Transpose the elements by generating a load from the reverse
|
||||
// indices.
|
||||
SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
|
||||
return rewriter.create<AffineLoadOp>(loc, input, reverseIvs);
|
||||
});
|
||||
// Transpose the elements by generating a load from the
|
||||
// reverse indices.
|
||||
SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
|
||||
return builder.create<AffineLoadOp>(loc, input,
|
||||
reverseIvs);
|
||||
});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -50,15 +50,14 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
|
|||
}
|
||||
|
||||
/// This defines the function type used to process an iteration of a lowered
|
||||
/// loop. It takes as input a rewriter, an array of memRefOperands corresponding
|
||||
/// to the operands of the input operation, and the set of loop induction
|
||||
/// variables for the iteration. It returns a value to store at the current
|
||||
/// index of the iteration.
|
||||
using LoopIterationFn = function_ref<Value(PatternRewriter &rewriter,
|
||||
ArrayRef<Value> memRefOperands,
|
||||
ArrayRef<Value> loopIvs)>;
|
||||
/// loop. It takes as input an OpBuilder, an range of memRefOperands
|
||||
/// corresponding to the operands of the input operation, and the range of loop
|
||||
/// induction variables for the iteration. It returns a value to store at the
|
||||
/// current index of the iteration.
|
||||
using LoopIterationFn = function_ref<Value(
|
||||
OpBuilder &rewriter, ValueRange memRefOperands, ValueRange loopIvs)>;
|
||||
|
||||
static void lowerOpToLoops(Operation *op, ArrayRef<Value> operands,
|
||||
static void lowerOpToLoops(Operation *op, ValueRange operands,
|
||||
PatternRewriter &rewriter,
|
||||
LoopIterationFn processIteration) {
|
||||
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
|
||||
|
@ -68,22 +67,21 @@ static void lowerOpToLoops(Operation *op, ArrayRef<Value> operands,
|
|||
auto memRefType = convertTensorToMemRef(tensorType);
|
||||
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
|
||||
|
||||
// Create an empty affine loop for each of the dimensions within the shape.
|
||||
SmallVector<Value, 4> loopIvs;
|
||||
for (auto dim : tensorType.getShape()) {
|
||||
auto loop = rewriter.create<AffineForOp>(loc, /*lb=*/0, dim, /*step=*/1);
|
||||
loopIvs.push_back(loop.getInductionVar());
|
||||
|
||||
// Update the rewriter insertion point to the beginning of the loop.
|
||||
rewriter.setInsertionPointToStart(loop.getBody());
|
||||
}
|
||||
|
||||
// Generate a call to the processing function with the rewriter, the memref
|
||||
// operands, and the loop induction variables. This function will return the
|
||||
// value to store at the current index.
|
||||
Value valueToStore = processIteration(rewriter, operands, loopIvs);
|
||||
rewriter.create<AffineStoreOp>(loc, valueToStore, alloc,
|
||||
llvm::makeArrayRef(loopIvs));
|
||||
// Create a nest of affine loops, with one loop per dimension of the shape.
|
||||
// The buildAffineLoopNest function takes a callback that is used to construct
|
||||
// the body of the innermost loop given a builder, a location and a range of
|
||||
// loop induction variables.
|
||||
SmallVector<int64_t, 4> lowerBounds(tensorType.getRank(), /*Value=*/0);
|
||||
SmallVector<int64_t, 4> steps(tensorType.getRank(), /*Value=*/1);
|
||||
buildAffineLoopNest(
|
||||
rewriter, loc, lowerBounds, tensorType.getShape(), steps,
|
||||
[&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) {
|
||||
// Call the processing function with the rewriter, the memref operands,
|
||||
// and the loop induction variables. This function will return the value
|
||||
// to store at the current index.
|
||||
Value valueToStore = processIteration(nestedBuilder, operands, ivs);
|
||||
nestedBuilder.create<AffineStoreOp>(loc, valueToStore, alloc, ivs);
|
||||
});
|
||||
|
||||
// Replace this operation with the generated alloc.
|
||||
rewriter.replaceOp(op, alloc);
|
||||
|
@ -105,8 +103,8 @@ struct BinaryOpLowering : public ConversionPattern {
|
|||
auto loc = op->getLoc();
|
||||
lowerOpToLoops(
|
||||
op, operands, rewriter,
|
||||
[loc](PatternRewriter &rewriter, ArrayRef<Value> memRefOperands,
|
||||
ArrayRef<Value> loopIvs) {
|
||||
[loc](OpBuilder &builder, ValueRange memRefOperands,
|
||||
ValueRange loopIvs) {
|
||||
// Generate an adaptor for the remapped operands of the BinaryOp. This
|
||||
// allows for using the nice named accessors that are generated by the
|
||||
// ODS.
|
||||
|
@ -115,12 +113,12 @@ struct BinaryOpLowering : public ConversionPattern {
|
|||
// Generate loads for the element of 'lhs' and 'rhs' at the inner
|
||||
// loop.
|
||||
auto loadedLhs =
|
||||
rewriter.create<AffineLoadOp>(loc, binaryAdaptor.lhs(), loopIvs);
|
||||
builder.create<AffineLoadOp>(loc, binaryAdaptor.lhs(), loopIvs);
|
||||
auto loadedRhs =
|
||||
rewriter.create<AffineLoadOp>(loc, binaryAdaptor.rhs(), loopIvs);
|
||||
builder.create<AffineLoadOp>(loc, binaryAdaptor.rhs(), loopIvs);
|
||||
|
||||
// Create the binary operation performed on the loaded values.
|
||||
return rewriter.create<LoweredBinaryOp>(loc, loadedLhs, loadedRhs);
|
||||
return builder.create<LoweredBinaryOp>(loc, loadedLhs, loadedRhs);
|
||||
});
|
||||
return success();
|
||||
}
|
||||
|
@ -227,21 +225,21 @@ struct TransposeOpLowering : public ConversionPattern {
|
|||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto loc = op->getLoc();
|
||||
lowerOpToLoops(
|
||||
op, operands, rewriter,
|
||||
[loc](PatternRewriter &rewriter, ArrayRef<Value> memRefOperands,
|
||||
ArrayRef<Value> loopIvs) {
|
||||
// Generate an adaptor for the remapped operands of the TransposeOp.
|
||||
// This allows for using the nice named accessors that are generated
|
||||
// by the ODS.
|
||||
toy::TransposeOpAdaptor transposeAdaptor(memRefOperands);
|
||||
Value input = transposeAdaptor.input();
|
||||
lowerOpToLoops(op, operands, rewriter,
|
||||
[loc](OpBuilder &builder, ValueRange memRefOperands,
|
||||
ValueRange loopIvs) {
|
||||
// Generate an adaptor for the remapped operands of the
|
||||
// TransposeOp. This allows for using the nice named
|
||||
// accessors that are generated by the ODS.
|
||||
toy::TransposeOpAdaptor transposeAdaptor(memRefOperands);
|
||||
Value input = transposeAdaptor.input();
|
||||
|
||||
// Transpose the elements by generating a load from the reverse
|
||||
// indices.
|
||||
SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
|
||||
return rewriter.create<AffineLoadOp>(loc, input, reverseIvs);
|
||||
});
|
||||
// Transpose the elements by generating a load from the
|
||||
// reverse indices.
|
||||
SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
|
||||
return builder.create<AffineLoadOp>(loc, input,
|
||||
reverseIvs);
|
||||
});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue