diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp index 0d35778d2d94..3097681ea3fa 100644 --- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp @@ -50,15 +50,14 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc, } /// This defines the function type used to process an iteration of a lowered -/// loop. It takes as input a rewriter, an array of memRefOperands corresponding -/// to the operands of the input operation, and the set of loop induction -/// variables for the iteration. It returns a value to store at the current -/// index of the iteration. -using LoopIterationFn = function_ref memRefOperands, - ArrayRef loopIvs)>; +/// loop. It takes as input an OpBuilder, an range of memRefOperands +/// corresponding to the operands of the input operation, and the range of loop +/// induction variables for the iteration. It returns a value to store at the +/// current index of the iteration. +using LoopIterationFn = function_ref; -static void lowerOpToLoops(Operation *op, ArrayRef operands, +static void lowerOpToLoops(Operation *op, ValueRange operands, PatternRewriter &rewriter, LoopIterationFn processIteration) { auto tensorType = (*op->result_type_begin()).cast(); @@ -68,22 +67,21 @@ static void lowerOpToLoops(Operation *op, ArrayRef operands, auto memRefType = convertTensorToMemRef(tensorType); auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); - // Create an empty affine loop for each of the dimensions within the shape. - SmallVector loopIvs; - for (auto dim : tensorType.getShape()) { - auto loop = rewriter.create(loc, /*lb=*/0, dim, /*step=*/1); - loopIvs.push_back(loop.getInductionVar()); - - // Update the rewriter insertion point to the beginning of the loop. - rewriter.setInsertionPointToStart(loop.getBody()); - } - - // Generate a call to the processing function with the rewriter, the memref - // operands, and the loop induction variables. This function will return the - // value to store at the current index. - Value valueToStore = processIteration(rewriter, operands, loopIvs); - rewriter.create(loc, valueToStore, alloc, - llvm::makeArrayRef(loopIvs)); + // Create a nest of affine loops, with one loop per dimension of the shape. + // The buildAffineLoopNest function takes a callback that is used to construct + // the body of the innermost loop given a builder, a location and a range of + // loop induction variables. + SmallVector lowerBounds(tensorType.getRank(), /*Value=*/0); + SmallVector steps(tensorType.getRank(), /*Value=*/1); + buildAffineLoopNest( + rewriter, loc, lowerBounds, tensorType.getShape(), steps, + [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) { + // Call the processing function with the rewriter, the memref operands, + // and the loop induction variables. This function will return the value + // to store at the current index. + Value valueToStore = processIteration(nestedBuilder, operands, ivs); + nestedBuilder.create(loc, valueToStore, alloc, ivs); + }); // Replace this operation with the generated alloc. rewriter.replaceOp(op, alloc); @@ -105,8 +103,8 @@ struct BinaryOpLowering : public ConversionPattern { auto loc = op->getLoc(); lowerOpToLoops( op, operands, rewriter, - [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, - ArrayRef loopIvs) { + [loc](OpBuilder &builder, ValueRange memRefOperands, + ValueRange loopIvs) { // Generate an adaptor for the remapped operands of the BinaryOp. This // allows for using the nice named accessors that are generated by the // ODS. @@ -115,12 +113,12 @@ struct BinaryOpLowering : public ConversionPattern { // Generate loads for the element of 'lhs' and 'rhs' at the inner // loop. auto loadedLhs = - rewriter.create(loc, binaryAdaptor.lhs(), loopIvs); + builder.create(loc, binaryAdaptor.lhs(), loopIvs); auto loadedRhs = - rewriter.create(loc, binaryAdaptor.rhs(), loopIvs); + builder.create(loc, binaryAdaptor.rhs(), loopIvs); // Create the binary operation performed on the loaded values. - return rewriter.create(loc, loadedLhs, loadedRhs); + return builder.create(loc, loadedLhs, loadedRhs); }); return success(); } @@ -227,21 +225,21 @@ struct TransposeOpLowering : public ConversionPattern { matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); - lowerOpToLoops( - op, operands, rewriter, - [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, - ArrayRef loopIvs) { - // Generate an adaptor for the remapped operands of the TransposeOp. - // This allows for using the nice named accessors that are generated - // by the ODS. - toy::TransposeOpAdaptor transposeAdaptor(memRefOperands); - Value input = transposeAdaptor.input(); + lowerOpToLoops(op, operands, rewriter, + [loc](OpBuilder &builder, ValueRange memRefOperands, + ValueRange loopIvs) { + // Generate an adaptor for the remapped operands of the + // TransposeOp. This allows for using the nice named + // accessors that are generated by the ODS. + toy::TransposeOpAdaptor transposeAdaptor(memRefOperands); + Value input = transposeAdaptor.input(); - // Transpose the elements by generating a load from the reverse - // indices. - SmallVector reverseIvs(llvm::reverse(loopIvs)); - return rewriter.create(loc, input, reverseIvs); - }); + // Transpose the elements by generating a load from the + // reverse indices. + SmallVector reverseIvs(llvm::reverse(loopIvs)); + return builder.create(loc, input, + reverseIvs); + }); return success(); } }; diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp index 7ee201785104..cac3415f48d6 100644 --- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp @@ -50,13 +50,12 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc, } /// This defines the function type used to process an iteration of a lowered -/// loop. It takes as input a rewriter, an array of memRefOperands corresponding -/// to the operands of the input operation, and the set of loop induction -/// variables for the iteration. It returns a value to store at the current -/// index of the iteration. -using LoopIterationFn = function_ref memRefOperands, - ArrayRef loopIvs)>; +/// loop. It takes as input an OpBuilder, an range of memRefOperands +/// corresponding to the operands of the input operation, and the range of loop +/// induction variables for the iteration. It returns a value to store at the +/// current index of the iteration. +using LoopIterationFn = function_ref; static void lowerOpToLoops(Operation *op, ArrayRef operands, PatternRewriter &rewriter, @@ -68,22 +67,21 @@ static void lowerOpToLoops(Operation *op, ArrayRef operands, auto memRefType = convertTensorToMemRef(tensorType); auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); - // Create an empty affine loop for each of the dimensions within the shape. - SmallVector loopIvs; - for (auto dim : tensorType.getShape()) { - auto loop = rewriter.create(loc, /*lb=*/0, dim, /*step=*/1); - loopIvs.push_back(loop.getInductionVar()); - - // Update the rewriter insertion point to the beginning of the loop. - rewriter.setInsertionPointToStart(loop.getBody()); - } - - // Generate a call to the processing function with the rewriter, the memref - // operands, and the loop induction variables. This function will return the - // value to store at the current index. - Value valueToStore = processIteration(rewriter, operands, loopIvs); - rewriter.create(loc, valueToStore, alloc, - llvm::makeArrayRef(loopIvs)); + // Create a nest of affine loops, with one loop per dimension of the shape. + // The buildAffineLoopNest function takes a callback that is used to construct + // the body of the innermost loop given a builder, a location and a range of + // loop induction variables. + SmallVector lowerBounds(tensorType.getRank(), /*Value=*/0); + SmallVector steps(tensorType.getRank(), /*Value=*/1); + buildAffineLoopNest( + rewriter, loc, lowerBounds, tensorType.getShape(), steps, + [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) { + // Call the processing function with the rewriter, the memref operands, + // and the loop induction variables. This function will return the value + // to store at the current index. + Value valueToStore = processIteration(nestedBuilder, operands, ivs); + nestedBuilder.create(loc, valueToStore, alloc, ivs); + }); // Replace this operation with the generated alloc. rewriter.replaceOp(op, alloc); @@ -105,8 +103,8 @@ struct BinaryOpLowering : public ConversionPattern { auto loc = op->getLoc(); lowerOpToLoops( op, operands, rewriter, - [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, - ArrayRef loopIvs) { + [loc](OpBuilder &builder, ValueRange memRefOperands, + ValueRange loopIvs) { // Generate an adaptor for the remapped operands of the BinaryOp. This // allows for using the nice named accessors that are generated by the // ODS. @@ -115,12 +113,12 @@ struct BinaryOpLowering : public ConversionPattern { // Generate loads for the element of 'lhs' and 'rhs' at the inner // loop. auto loadedLhs = - rewriter.create(loc, binaryAdaptor.lhs(), loopIvs); + builder.create(loc, binaryAdaptor.lhs(), loopIvs); auto loadedRhs = - rewriter.create(loc, binaryAdaptor.rhs(), loopIvs); + builder.create(loc, binaryAdaptor.rhs(), loopIvs); // Create the binary operation performed on the loaded values. - return rewriter.create(loc, loadedLhs, loadedRhs); + return builder.create(loc, loadedLhs, loadedRhs); }); return success(); } @@ -226,21 +224,21 @@ struct TransposeOpLowering : public ConversionPattern { matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); - lowerOpToLoops( - op, operands, rewriter, - [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, - ArrayRef loopIvs) { - // Generate an adaptor for the remapped operands of the TransposeOp. - // This allows for using the nice named accessors that are generated - // by the ODS. - toy::TransposeOpAdaptor transposeAdaptor(memRefOperands); - Value input = transposeAdaptor.input(); + lowerOpToLoops(op, operands, rewriter, + [loc](OpBuilder &builder, ValueRange memRefOperands, + ValueRange loopIvs) { + // Generate an adaptor for the remapped operands of the + // TransposeOp. This allows for using the nice named + // accessors that are generated by the ODS. + toy::TransposeOpAdaptor transposeAdaptor(memRefOperands); + Value input = transposeAdaptor.input(); - // Transpose the elements by generating a load from the reverse - // indices. - SmallVector reverseIvs(llvm::reverse(loopIvs)); - return rewriter.create(loc, input, reverseIvs); - }); + // Transpose the elements by generating a load from the + // reverse indices. + SmallVector reverseIvs(llvm::reverse(loopIvs)); + return builder.create(loc, input, + reverseIvs); + }); return success(); } }; diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp index 0d35778d2d94..3097681ea3fa 100644 --- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp @@ -50,15 +50,14 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc, } /// This defines the function type used to process an iteration of a lowered -/// loop. It takes as input a rewriter, an array of memRefOperands corresponding -/// to the operands of the input operation, and the set of loop induction -/// variables for the iteration. It returns a value to store at the current -/// index of the iteration. -using LoopIterationFn = function_ref memRefOperands, - ArrayRef loopIvs)>; +/// loop. It takes as input an OpBuilder, an range of memRefOperands +/// corresponding to the operands of the input operation, and the range of loop +/// induction variables for the iteration. It returns a value to store at the +/// current index of the iteration. +using LoopIterationFn = function_ref; -static void lowerOpToLoops(Operation *op, ArrayRef operands, +static void lowerOpToLoops(Operation *op, ValueRange operands, PatternRewriter &rewriter, LoopIterationFn processIteration) { auto tensorType = (*op->result_type_begin()).cast(); @@ -68,22 +67,21 @@ static void lowerOpToLoops(Operation *op, ArrayRef operands, auto memRefType = convertTensorToMemRef(tensorType); auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); - // Create an empty affine loop for each of the dimensions within the shape. - SmallVector loopIvs; - for (auto dim : tensorType.getShape()) { - auto loop = rewriter.create(loc, /*lb=*/0, dim, /*step=*/1); - loopIvs.push_back(loop.getInductionVar()); - - // Update the rewriter insertion point to the beginning of the loop. - rewriter.setInsertionPointToStart(loop.getBody()); - } - - // Generate a call to the processing function with the rewriter, the memref - // operands, and the loop induction variables. This function will return the - // value to store at the current index. - Value valueToStore = processIteration(rewriter, operands, loopIvs); - rewriter.create(loc, valueToStore, alloc, - llvm::makeArrayRef(loopIvs)); + // Create a nest of affine loops, with one loop per dimension of the shape. + // The buildAffineLoopNest function takes a callback that is used to construct + // the body of the innermost loop given a builder, a location and a range of + // loop induction variables. + SmallVector lowerBounds(tensorType.getRank(), /*Value=*/0); + SmallVector steps(tensorType.getRank(), /*Value=*/1); + buildAffineLoopNest( + rewriter, loc, lowerBounds, tensorType.getShape(), steps, + [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) { + // Call the processing function with the rewriter, the memref operands, + // and the loop induction variables. This function will return the value + // to store at the current index. + Value valueToStore = processIteration(nestedBuilder, operands, ivs); + nestedBuilder.create(loc, valueToStore, alloc, ivs); + }); // Replace this operation with the generated alloc. rewriter.replaceOp(op, alloc); @@ -105,8 +103,8 @@ struct BinaryOpLowering : public ConversionPattern { auto loc = op->getLoc(); lowerOpToLoops( op, operands, rewriter, - [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, - ArrayRef loopIvs) { + [loc](OpBuilder &builder, ValueRange memRefOperands, + ValueRange loopIvs) { // Generate an adaptor for the remapped operands of the BinaryOp. This // allows for using the nice named accessors that are generated by the // ODS. @@ -115,12 +113,12 @@ struct BinaryOpLowering : public ConversionPattern { // Generate loads for the element of 'lhs' and 'rhs' at the inner // loop. auto loadedLhs = - rewriter.create(loc, binaryAdaptor.lhs(), loopIvs); + builder.create(loc, binaryAdaptor.lhs(), loopIvs); auto loadedRhs = - rewriter.create(loc, binaryAdaptor.rhs(), loopIvs); + builder.create(loc, binaryAdaptor.rhs(), loopIvs); // Create the binary operation performed on the loaded values. - return rewriter.create(loc, loadedLhs, loadedRhs); + return builder.create(loc, loadedLhs, loadedRhs); }); return success(); } @@ -227,21 +225,21 @@ struct TransposeOpLowering : public ConversionPattern { matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); - lowerOpToLoops( - op, operands, rewriter, - [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, - ArrayRef loopIvs) { - // Generate an adaptor for the remapped operands of the TransposeOp. - // This allows for using the nice named accessors that are generated - // by the ODS. - toy::TransposeOpAdaptor transposeAdaptor(memRefOperands); - Value input = transposeAdaptor.input(); + lowerOpToLoops(op, operands, rewriter, + [loc](OpBuilder &builder, ValueRange memRefOperands, + ValueRange loopIvs) { + // Generate an adaptor for the remapped operands of the + // TransposeOp. This allows for using the nice named + // accessors that are generated by the ODS. + toy::TransposeOpAdaptor transposeAdaptor(memRefOperands); + Value input = transposeAdaptor.input(); - // Transpose the elements by generating a load from the reverse - // indices. - SmallVector reverseIvs(llvm::reverse(loopIvs)); - return rewriter.create(loc, input, reverseIvs); - }); + // Transpose the elements by generating a load from the + // reverse indices. + SmallVector reverseIvs(llvm::reverse(loopIvs)); + return builder.create(loc, input, + reverseIvs); + }); return success(); } };