diff --git a/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h b/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h index a1c94f19d3f2..973b995f10bb 100644 --- a/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h +++ b/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h @@ -21,9 +21,9 @@ namespace mlir { class AffineForOp; struct LogicalResult; -namespace linalg { +namespace loop { class ForOp; -} +} // end namespace loop /// Convert a perfect affine loop nest with the outermost loop identified by /// `forOp` into a gpu::Launch operation. Map `numBlockDims` outer loops to @@ -49,9 +49,9 @@ LogicalResult convertAffineLoopNestToGPULaunch(AffineForOp forOp, /// parallelization is performed, it is under the responsibility of the caller /// to strip-mine the loops and to perform the dependence analysis before /// calling the conversion. -LogicalResult convertLinalgLoopNestToGPULaunch(linalg::ForOp forOp, - unsigned numBlockDims, - unsigned numThreadDims); +LogicalResult convertLoopNestToGPULaunch(loop::ForOp forOp, + unsigned numBlockDims, + unsigned numThreadDims); } // namespace mlir #endif // MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPU_H_ diff --git a/mlir/include/mlir/Linalg/IR/LinalgLibraryOps.td b/mlir/include/mlir/Linalg/IR/LinalgLibraryOps.td index 05ca9ce66dea..a3796d2e2df5 100644 --- a/mlir/include/mlir/Linalg/IR/LinalgLibraryOps.td +++ b/mlir/include/mlir/Linalg/IR/LinalgLibraryOps.td @@ -98,9 +98,9 @@ def CopyOp : LinalgLibrary_Op<"copy", [NInputsAndOutputs<1, 1>]> { Usage: linalg.copy(%arg0, %arg1) : !linalg.view, !linalg.view - One possible lowering to affine form is: + One possible lowering to loop form is: %0 = linalg.dim %arg0, 0 : index - linalg.for %i0 = %c0 to %0 step %c1 { + loop.for %i0 = %c0 to %0 step %c1 { %1 = linalg.load %arg0[%i0] : !linalg.view linalg.store %1, %arg1[%i0] : !linalg.view } @@ -113,13 +113,13 @@ def CopyOp : LinalgLibrary_Op<"copy", [NInputsAndOutputs<1, 1>]> { outputPermutation : (i, j, k) -> (k, j, i)} : !linalg.view, !linalg.view - One possible lowering to affine form is: + One possible lowering to loop form is: %0 = linalg.dim %arg0, 0 %1 = linalg.dim %arg0, 1 %2 = linalg.dim %arg0, 2 - linalg.for %i0 = %c0 to %{{.*}} step %c1 { - linalg.for %i1 = %c0 to %{{.*}} step %c1 { - linalg.for %i2 = %c0 to %{{.*}} step %c1 { + loop.for %i0 = %c0 to %{{.*}} step %c1 { + loop.for %i1 = %c0 to %{{.*}} step %c1 { + loop.for %i2 = %c0 to %{{.*}} step %c1 { %3 = linalg.load %arg0[%i0, %i2, %i1] : !linalg.view linalg.store %3, %arg1[%i2, %i1, %i0] : !linalg.view diff --git a/mlir/include/mlir/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Linalg/IR/LinalgOps.h index 053d20a376a3..9a167662d38b 100644 --- a/mlir/include/mlir/Linalg/IR/LinalgOps.h +++ b/mlir/include/mlir/Linalg/IR/LinalgOps.h @@ -29,82 +29,6 @@ class OperationFolder; namespace linalg { -/// The "linalg.for" operation represents a loop nest taking 3 SSA value as -/// operands that represent the lower bound, upper bound and step respectively. -/// The operation defines an SSA value for its induction variable. It has one -/// region capturing the loop body. The induction variable is represented as an -/// argument of this region. This SSA value always has type index, which is the -/// size of the machine word. The step is a value of type index, required to be -/// positive. -/// The lower and upper bounds specify a half-open range: the range includes the -/// lower bound but does not include the upper bound. -/// -/// The body region must contain exactly one block that terminates with -/// "linalg.terminator". Calling linalg::ForOp::build will create such region -/// and insert the terminator, so will the parsing even in cases if it is absent -/// from the custom format. For example: -/// -/// ```mlir -/// linalg.for %iv = %lb to %ub step %step { -/// ... // body -/// } -/// ``` -class ForOp - : public Op::Impl, OpTrait::ZeroResult> { -public: - using Op::Op; - - // Hooks to customize behavior of this op. - static void build(Builder *builder, OperationState *result, Value *lb, - Value *ub, Value *step); - LogicalResult verify(); - static ParseResult parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p); - - static StringRef getOperationName() { return "linalg.for"; } - - /// Return a Builder set up to insert operations immediately before the - /// terminator. - OpBuilder getBodyBuilder() { - Block *body = getBody(); - return OpBuilder(body, std::prev(body->end())); - } - - /// Get the body of the ForOp. - Block *getBody() { return &getRegion().front(); } - - /// Get the body region of the ForOp. - Region &getRegion() { return getOperation()->getRegion(0); } - - /// Returns the induction variable for this loop. - Value *getInductionVar() { return getBody()->getArgument(0); } - - //===--------------------------------------------------------------------===// - // Bounds and step - //===--------------------------------------------------------------------===// - /// Returns the lower bound operand. - Value *getLowerBound() { return getOperand(0); } - - /// Returns the upper bound operand. - Value *getUpperBound() { return getOperand(1); } - - /// Returns loop step. - Value *getStep() { return getOperand(2); } - - /// Set lower bound. - void setLowerBound(Value *lb) { setOperand(0, lb); } - - /// Set upper bound. - void setUpperBound(Value *ub) { setOperand(1, ub); } - - /// Set loop step. - void setStep(Value *step) { setOperand(2, step); } -}; - -/// Returns the loop parent of an induction variable. If the provided value is -/// not an induction variable, then return nullptr. -ForOp getForInductionVarOwner(Value *val); - /// A linalg.LoadOp is the counterpart of load but operating on ViewType /// instead of MemRefType. /// diff --git a/mlir/include/mlir/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Linalg/IR/LinalgOps.td index 49b75be48818..6bf39ee01d8b 100644 --- a/mlir/include/mlir/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Linalg/IR/LinalgOps.td @@ -171,28 +171,6 @@ def RangeIntersectOp : Linalg_Op<"range_intersect", [NoSideEffect]>, }]>]; } -def TerminatorOp : - Linalg_Op<"terminator", [NativeOpTrait<"IsTerminator">]> { - let summary = "linalg terminator operation"; - let description = [{ - "linalg.terminator" is a special terminator operation for blocks inside - linalg loops and branches. It unconditionally transmits the control flow to - the successor of the operation enclosing the region. - - This operation does _not_ have a custom syntax. However, linalg control - operations omit the terminator in their custom syntax for brevity. - - linalg.terminator - }]; - - // No custom parsing/printing form. - let parser = ?; - let printer = ?; - - // Fully specified by traits. - let verifier = ?; -} - def SubViewOp : Linalg_Op<"subview", [NoSideEffect]>, Arguments<(ins View:$view, Variadic:$ranges)>, Results<(outs View)> { diff --git a/mlir/include/mlir/Linalg/Utils/Utils.h b/mlir/include/mlir/Linalg/Utils/Utils.h index 53a57fe0a5a9..1c0335985d74 100644 --- a/mlir/include/mlir/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Linalg/Utils/Utils.h @@ -18,6 +18,7 @@ #ifndef MLIR_LINALG_UTILS_H_ #define MLIR_LINALG_UTILS_H_ +#include "mlir/Dialect/LoopOps/LoopOps.h" #include "mlir/EDSC/Helpers.h" #include "mlir/Linalg/IR/LinalgOps.h" #include "mlir/Support/LLVM.h" @@ -26,15 +27,16 @@ namespace mlir { class AffineExpr; class AffineMap; class OperationFolder; + namespace edsc { -/// A LoopRangeBuilder is a generic NestedBuilder for linalg.for operations. +/// A LoopRangeBuilder is a generic NestedBuilder for loop.for operations. /// More specifically it is meant to be used as a temporary object for /// representing any nested MLIR construct that is "related to" an mlir::Value* /// (for now an induction variable). class LoopRangeBuilder : public NestedBuilder { public: - /// Constructs a new linalg::ForOp and captures the associated induction + /// Constructs a new loop.for and captures the associated induction /// variable. A ValueHandle pointer is passed as the first argument and is the /// *only* way to capture the loop induction variable. LoopRangeBuilder(ValueHandle *iv, ValueHandle range); @@ -53,9 +55,9 @@ public: ValueHandle operator()(std::function fun = nullptr); }; -/// Helper class to sugar building linalg.for loop nests from ranges. +/// Helper class to sugar building loop.for loop nests from ranges. /// This is similar to edsc::LoopNestBuilder except it works on ranges directly. -/// In the current implementation it produces linalg.for operations. +/// In the current implementation it produces loop.for operations. class LoopNestRangeBuilder { public: LoopNestRangeBuilder(llvm::ArrayRef ivs, @@ -88,7 +90,7 @@ SmallVector applyMapToValues(OpBuilder &b, Location loc, struct TiledLinalgOp { LinalgOp op; - SmallVector loops; + SmallVector loops; }; /// Performs standalone tiling of a single LinalgOp by `tileSizes`. diff --git a/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp b/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp index f5d0ccef2fa1..5064bbaab675 100644 --- a/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp +++ b/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp @@ -15,8 +15,8 @@ // limitations under the License. // ============================================================================= // -// This file implements a pass to convert std.for, std.if and std.terminator ops -// into standard CFG ops. +// This file implements a pass to convert loop.for, loop.if and loop.terminator +// ops into standard CFG ops. // //===----------------------------------------------------------------------===// @@ -54,7 +54,7 @@ struct ControlFlowToCFGPass : public FunctionPass { // first/last blocks in the parent region. The original loop operation is // replaced by the initialization operations that set up the initial value of // the loop induction variable (%iv) and computes the loop bounds that are loop- -// invariant for affine loops. The operations following the original std.for +// invariant for affine loops. The operations following the original loop.for // are split out into a separate continuation (exit) block. A condition block is // created before the continuation block. It checks the exit condition of the // loop and branches either to the continuation block, or to the first block of @@ -108,14 +108,14 @@ struct ForLowering : public ConversionPattern { PatternRewriter &rewriter) const override; }; -// Create a CFG subgraph for the std.if operation (including its "then" and +// Create a CFG subgraph for the loop.if operation (including its "then" and // optional "else" operation blocks). We maintain the invariants that the // subgraph has a single entry and a single exit point, and that the entry/exit // blocks are respectively the first/last block of the enclosing region. The -// operations following the std.if are split into a continuation (subgraph +// operations following the loop.if are split into a continuation (subgraph // exit) block. The condition is lowered to a chain of blocks that implement the // short-circuit scheme. Condition blocks are created by splitting out an empty -// block from the block that contains the std.if operation. They +// block from the block that contains the loop.if operation. They // conditionally branch to either the first block of the "then" region, or to // the first block of the "else" region. If the latter is absent, they branch // to the continuation block instead. The last blocks of "then" and "else" @@ -232,14 +232,14 @@ IfLowering::matchAndRewrite(Operation *op, ArrayRef operands, auto ifOp = cast(op); auto loc = op->getLoc(); - // Start by splitting the block containing the 'std.if' into two parts. + // Start by splitting the block containing the 'loop.if' into two parts. // The part before will contain the condition, the part after will be the // continuation point. auto *condBlock = rewriter.getInsertionBlock(); auto opPosition = rewriter.getInsertionPoint(); auto *continueBlock = rewriter.splitBlock(condBlock, opPosition); - // Move blocks from the "then" region to the region containing 'std.if', + // Move blocks from the "then" region to the region containing 'loop.if', // place it before the continuation block, and branch to it. auto &thenRegion = ifOp.thenRegion(); auto *thenBlock = &thenRegion.front(); @@ -248,7 +248,7 @@ IfLowering::matchAndRewrite(Operation *op, ArrayRef operands, rewriter.inlineRegionBefore(thenRegion, continueBlock); // Move blocks from the "else" region (if present) to the region containing - // 'std.if', place it before the continuation block and branch to it. It + // 'loop.if', place it before the continuation block and branch to it. It // will be placed after the "then" regions. auto *elseBlock = continueBlock; auto &elseRegion = ifOp.elseRegion(); diff --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp index 96ac947a1e03..ac164ab816fe 100644 --- a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp +++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp @@ -23,10 +23,10 @@ #include "mlir/Conversion/LoopsToGPU/LoopsToGPU.h" #include "mlir/AffineOps/AffineOps.h" +#include "mlir/Dialect/LoopOps/LoopOps.h" #include "mlir/GPU/GPUDialect.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Builders.h" -#include "mlir/Linalg/IR/LinalgOps.h" #include "mlir/StandardOps/Ops.h" #include "mlir/Transforms/LowerAffine.h" #include "mlir/Transforms/RegionUtils.h" @@ -36,6 +36,7 @@ #define DEBUG_TYPE "loops-to-gpu" using namespace mlir; +using namespace mlir::loop; // Extract an indexed value from KernelDim3. static Value *getDim3Value(const gpu::KernelDim3 &dim3, unsigned pos) { @@ -56,8 +57,8 @@ static Value *getDim3Value(const gpu::KernelDim3 &dim3, unsigned pos) { static Operation::operand_range getLowerBoundOperands(AffineForOp forOp) { return forOp.getLowerBoundOperands(); } -static SmallVector getLowerBoundOperands(linalg::ForOp forOp) { - SmallVector bounds(1, forOp.getLowerBound()); +static SmallVector getLowerBoundOperands(ForOp forOp) { + SmallVector bounds(1, forOp.lowerBound()); return bounds; } @@ -65,8 +66,8 @@ static SmallVector getLowerBoundOperands(linalg::ForOp forOp) { static Operation::operand_range getUpperBoundOperands(AffineForOp forOp) { return forOp.getUpperBoundOperands(); } -static SmallVector getUpperBoundOperands(linalg::ForOp forOp) { - SmallVector bounds(1, forOp.getUpperBound()); +static SmallVector getUpperBoundOperands(ForOp forOp) { + SmallVector bounds(1, forOp.upperBound()); return bounds; } @@ -75,17 +76,15 @@ static SmallVector getUpperBoundOperands(linalg::ForOp forOp) { static Value *getOrCreateStep(AffineForOp forOp, OpBuilder &builder) { return builder.create(forOp.getLoc(), forOp.getStep()); } -static Value *getOrCreateStep(linalg::ForOp forOp, OpBuilder &) { - return forOp.getStep(); -} +static Value *getOrCreateStep(ForOp forOp, OpBuilder &) { return forOp.step(); } // Get a Value for the loop lower bound. If the value requires computation, // materialize the instructions using builder. static Value *getOrEmitLowerBound(AffineForOp forOp, OpBuilder &builder) { return lowerAffineLowerBound(forOp, builder); } -static Value *getOrEmitLowerBound(linalg::ForOp forOp, OpBuilder &) { - return forOp.getLowerBound(); +static Value *getOrEmitLowerBound(ForOp forOp, OpBuilder &) { + return forOp.lowerBound(); } // Get a Value for the loop upper bound. If the value requires computation, @@ -93,10 +92,16 @@ static Value *getOrEmitLowerBound(linalg::ForOp forOp, OpBuilder &) { static Value *getOrEmitUpperBound(AffineForOp forOp, OpBuilder &builder) { return lowerAffineUpperBound(forOp, builder); } -static Value *getOrEmitUpperBound(linalg::ForOp forOp, OpBuilder &) { - return forOp.getUpperBound(); +static Value *getOrEmitUpperBound(ForOp forOp, OpBuilder &) { + return forOp.upperBound(); } +// TODO(ntv): uniformize back once AffineForOp is in ODS. +static Region &getRegion(ForOp op) { return op.region(); } +static Region &getRegion(AffineForOp op) { return op.getRegion(); } +static Block *getBody(ForOp op) { return op.body(); } +static Block *getBody(AffineForOp op) { return op.getBody(); } + // Check the structure of the loop nest: // - there are enough loops to map to numBlockDims + numThreadDims; // - the loops are perfectly nested; @@ -122,9 +127,9 @@ LogicalResult checkLoopNestMappable(OpTy forOp, unsigned numBlockDims, } OpTy currentLoop = forOp; - Region &limit = forOp.getRegion(); + Region &limit = getRegion(forOp); for (unsigned i = 0, e = numBlockDims + numThreadDims; i < e; ++i) { - Operation *nested = ¤tLoop.getBody()->front(); + Operation *nested = &getBody(currentLoop)->front(); if (!areValuesDefinedAbove(getLowerBoundOperands(currentLoop), limit) || !areValuesDefinedAbove(getUpperBoundOperands(currentLoop), limit)) return currentLoop.emitError( @@ -136,9 +141,9 @@ LogicalResult checkLoopNestMappable(OpTy forOp, unsigned numBlockDims, if (i == e - 1) break; - auto begin = currentLoop.getBody()->begin(), - end = currentLoop.getBody()->end(); - if (currentLoop.getBody()->empty() || std::next(begin, 2) != end) + auto begin = getBody(currentLoop)->begin(), + end = getBody(currentLoop)->end(); + if (getBody(currentLoop)->empty() || std::next(begin, 2) != end) return currentLoop.emitError( "expected perfectly nested loops in the body"); @@ -211,7 +216,7 @@ Optional LoopToGpuConverter::collectBounds(OpTy forOp, steps.push_back(step); if (i != numLoops - 1) - currentLoop = cast(¤tLoop.getBody()->front()); + currentLoop = cast(&getBody(currentLoop)->front()); } return currentLoop; } @@ -243,7 +248,7 @@ void LoopToGpuConverter::createLaunch(OpTy rootForOp, OpTy innermostForOp, // Still assuming perfect nesting so there are no values other than induction // variables that are defined in one loop and used in deeper loops. llvm::SetVector valuesToForwardSet; - getUsedValuesDefinedAbove(innermostForOp.getRegion(), rootForOp.getRegion(), + getUsedValuesDefinedAbove(getRegion(innermostForOp), getRegion(rootForOp), valuesToForwardSet); auto valuesToForward = valuesToForwardSet.takeVector(); auto originallyForwardedValues = valuesToForward.size(); @@ -258,14 +263,14 @@ void LoopToGpuConverter::createLaunch(OpTy rootForOp, OpTy innermostForOp, // gpu return and move the operations from the loop body block to the gpu // launch body block. Do not move the entire block because of the difference // in block arguments. - Operation &terminator = innermostForOp.getBody()->back(); + Operation &terminator = getBody(innermostForOp)->back(); Location terminatorLoc = terminator.getLoc(); terminator.erase(); - builder.setInsertionPointToEnd(innermostForOp.getBody()); + builder.setInsertionPointToEnd(getBody(innermostForOp)); builder.create(terminatorLoc); launchOp.getBody().front().getOperations().splice( launchOp.getBody().front().begin(), - innermostForOp.getBody()->getOperations()); + getBody(innermostForOp)->getOperations()); // Remap the loop iterators to use block/thread identifiers instead. Loops // may iterate from LB with step S whereas GPU thread/block ids always iterate @@ -328,11 +333,11 @@ static LogicalResult convertLoopNestToGPULaunch(OpTy forOp, LogicalResult mlir::convertAffineLoopNestToGPULaunch(AffineForOp forOp, unsigned numBlockDims, unsigned numThreadDims) { - return convertLoopNestToGPULaunch(forOp, numBlockDims, numThreadDims); + return ::convertLoopNestToGPULaunch(forOp, numBlockDims, numThreadDims); } -LogicalResult mlir::convertLinalgLoopNestToGPULaunch(linalg::ForOp forOp, - unsigned numBlockDims, - unsigned numThreadDims) { - return convertLoopNestToGPULaunch(forOp, numBlockDims, numThreadDims); +LogicalResult mlir::convertLoopNestToGPULaunch(ForOp forOp, + unsigned numBlockDims, + unsigned numThreadDims) { + return ::convertLoopNestToGPULaunch(forOp, numBlockDims, numThreadDims); } diff --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp index 13e4171033e7..7c785b5c9953 100644 --- a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp +++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp @@ -18,7 +18,7 @@ #include "mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h" #include "mlir/AffineOps/AffineOps.h" #include "mlir/Conversion/LoopsToGPU/LoopsToGPU.h" -#include "mlir/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/LoopOps/LoopOps.h" #include "mlir/Pass/Pass.h" #include "llvm/Support/CommandLine.h" @@ -26,6 +26,7 @@ #define PASS_NAME "convert-loops-to-gpu" using namespace mlir; +using namespace mlir::loop; static llvm::cl::OptionCategory clOptionsCategory(PASS_NAME " options"); static llvm::cl::opt @@ -52,9 +53,9 @@ struct ForLoopMapper : public FunctionPass { if (failed(convertAffineLoopNestToGPULaunch(forOp, numBlockDims, numThreadDims))) signalPassFailure(); - } else if (auto forOp = dyn_cast(&op)) { - if (failed(convertLinalgLoopNestToGPULaunch(forOp, numBlockDims, - numThreadDims))) + } else if (auto forOp = dyn_cast(&op)) { + if (failed(convertLoopNestToGPULaunch(forOp, numBlockDims, + numThreadDims))) signalPassFailure(); } } diff --git a/mlir/lib/Linalg/IR/LinalgOps.cpp b/mlir/lib/Linalg/IR/LinalgOps.cpp index 4d8b17940ec1..fa1a31586af8 100644 --- a/mlir/lib/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Linalg/IR/LinalgOps.cpp @@ -20,6 +20,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/LoopOps/LoopOps.h" #include "mlir/EDSC/Helpers.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -37,120 +38,6 @@ using namespace mlir::edsc; using namespace mlir::edsc::intrinsics; using namespace mlir::linalg; -//////////////////////////////////////////////////////////////////////////////// -// ForOp. -//////////////////////////////////////////////////////////////////////////////// -// Check that if a "block" has a terminator, it is an `TerminatorOp`. -static LogicalResult checkHasTerminator(OpState &op, Block &block) { - if (block.empty() || isa(block.back())) - return success(); - - return op.emitOpError("expects regions to end with '" + - linalg::TerminatorOp::getOperationName() + "'") - .attachNote() - << "in custom textual format, the absence of terminator implies '" - << linalg::TerminatorOp::getOperationName() << "'"; -} - -// Insert `linalg.terminator` at the end of the ForOp only region's only block -// if it does not have a terminator already. If a new `linalg.terminator` is -// inserted, the location is specified by `loc`. If the region is empty, insert -// a new block first. -static void ensureTerminator(Region ®ion, Builder &builder, Location loc) { - impl::ensureRegionTerminator(region, builder, loc); -} - -void mlir::linalg::ForOp::build(Builder *builder, OperationState *result, - Value *lb, Value *ub, Value *step) { - result->addOperands({lb, ub, step}); - Region *bodyRegion = result->addRegion(); - Block *body = new Block(); - body->addArgument(IndexType::get(builder->getContext())); - bodyRegion->push_back(body); - ensureTerminator(*bodyRegion, *builder, result->location); -} - -LogicalResult mlir::linalg::ForOp::verify() { - if (!getLowerBound()->getType().isa()) - return emitOpError("lower bound operand must be an index"); - if (!getUpperBound()->getType().isa()) - return emitOpError("upper bound operand must be an index"); - if (!getStep()->getType().dyn_cast()) - return emitOpError("step operand must be an index"); - if (auto cst = dyn_cast_or_null(getStep()->getDefiningOp())) - if (cst.getValue() <= 0) - return emitOpError("constant step operand must be positive"); - - if (std::next(getOperation()->getRegions().begin()) != - getOperation()->getRegions().end()) - return emitOpError("operation expected to have exactly one region"); - - auto &bodyRegion = getOperation()->getRegion(0); - // The body region must contain a single basic block. - if (bodyRegion.empty() || std::next(bodyRegion.begin()) != bodyRegion.end()) - return emitOpError("expected body region to have a single block"); - // Check that the body defines as single block argument for the induction - // variable. - auto *body = getBody(); - if (body->getNumArguments() != 1 || - !body->getArgument(0)->getType().isIndex()) - return emitOpError("expected body to have a single index argument for " - "the induction variable"); - if (failed(checkHasTerminator(*this, *body))) - return failure(); - return success(); -} - -void mlir::linalg::ForOp::print(OpAsmPrinter *p) { - *p << getOperationName() << " " << *getInductionVar() << " = " - << *getLowerBound() << " to " << *getUpperBound() << " step " - << *getStep(); - p->printRegion(getRegion(), - /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/false); - p->printOptionalAttrDict(getAttrs()); -} - -ParseResult mlir::linalg::ForOp::parse(OpAsmParser *parser, - OperationState *result) { - auto &builder = parser->getBuilder(); - OpAsmParser::OperandType inductionVariable, lb, ub, step; - // Parse the induction variable followed by '='. - if (parser->parseRegionArgument(inductionVariable) || parser->parseEqual()) - return failure(); - - // Parse loop bounds. - Type indexType = builder.getIndexType(); - if (parser->parseOperand(lb) || - parser->resolveOperand(lb, indexType, result->operands) || - parser->parseKeyword("to") || parser->parseOperand(ub) || - parser->resolveOperand(ub, indexType, result->operands) || - parser->parseKeyword("step") || parser->parseOperand(step) || - parser->resolveOperand(step, indexType, result->operands)) - return failure(); - - // Parse the body region. - Region *body = result->addRegion(); - if (parser->parseRegion(*body, inductionVariable, indexType)) - return failure(); - - ensureTerminator(*body, builder, result->location); - - // Parse the optional attribute list. - if (parser->parseOptionalAttributeDict(result->attributes)) - return failure(); - - return success(); -} - -mlir::linalg::ForOp mlir::linalg::getForInductionVarOwner(Value *val) { - auto *ivArg = dyn_cast(val); - if (!ivArg) - return ForOp(); - assert(ivArg->getOwner() && "unlinked block argument"); - return dyn_cast(ivArg->getOwner()->getContainingOp()); -} - //////////////////////////////////////////////////////////////////////////////// // LoadOp. //////////////////////////////////////////////////////////////////////////////// @@ -993,7 +880,7 @@ void mlir::linalg::emitScalarImplementation( OpBuilder b(linalgOp.getOperation()); auto nLoops = nPar + nRed + nWin; if (nLoops > 0) { - auto innermostLoop = linalg::getForInductionVarOwner(allIvs.back()); + auto innermostLoop = loop::getForInductionVarOwner(allIvs.back()); // accounts for linalg.terminator in loop. b = innermostLoop.getBodyBuilder(); } diff --git a/mlir/lib/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Linalg/IR/LinalgTypes.cpp index e5b2faf12578..61acbce7b024 100644 --- a/mlir/lib/Linalg/IR/LinalgTypes.cpp +++ b/mlir/lib/Linalg/IR/LinalgTypes.cpp @@ -35,7 +35,7 @@ using namespace mlir::linalg; mlir::linalg::LinalgDialect::LinalgDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) { addTypes(); - addOperations(); + addOperations(); addOperations< #define GET_OP_LIST #include "mlir/Linalg/IR/LinalgOps.cpp.inc" diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index b2e964edfbc0..298f97812546 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -15,6 +15,7 @@ // limitations under the License. // ============================================================================= +#include "mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/EDSC/Builders.h" @@ -746,72 +747,17 @@ static void lowerLinalgSubViewOps(FuncOp &f) { }); } -// Converts a `linalg.for` op to CFG form before actual conversion to the LLVM -// dialect starts. -static void lowerLinalgForToCFG(FuncOp &f) { - // Collect all the For operations. We do this as a prepass to avoid - // invalidating the walker with our rewrite. - SmallVector instsToRewrite; - f.walk( - [&](linalg::ForOp op) { instsToRewrite.push_back(op); }); - - for (auto forOp : llvm::reverse(instsToRewrite)) { - auto *op = forOp.getOperation(); - auto loc = op->getLoc(); - using namespace edsc::op; - OpBuilder builder(op); - ScopedContext scope(builder, loc); - ValueHandle lb(forOp.getLowerBound()), ub(forOp.getUpperBound()), - step(forOp.getStep()); - - // 1. Split Block into init and end blocks, create body and condition blocks - // with the `iv` block argument. - auto *initBlock = op->getBlock(); - auto *endBlock = initBlock->splitBlock(op); - BlockHandle conditionBlock, bodyBlock; - ValueHandle iv(IndexType::get(op->getContext())); - BlockBuilder(&conditionBlock, {&iv})(); - BlockBuilder(&bodyBlock, {})(); - - // 2. Create and fill the condition block whose sole purpose is to evaluate - // iv and branch to either `bodyBlock` or `endBlock`. Add all branches to - // the `conditionBlock`. - // clang-format off - BlockBuilder(conditionBlock, Append())([&] { - auto cmp = cmpi(CmpIPredicate::SGT, ub, iv); - cond_br(cmp, bodyBlock, {}, endBlock, {}); - }); - BlockBuilder(bodyBlock, Append())([&] { - br(conditionBlock, addi(iv, step)); - }); - BlockBuilder(initBlock, Append())([&] { - br(conditionBlock, lb); - }); - // clang-format on - - // 3. Move the instructions from the for loop to the body, update all uses - // of the induction variable and clean up. - auto *oldBody = forOp.getBody(); - bodyBlock.getBlock()->getOperations().splice( - bodyBlock.getBlock()->begin(), oldBody->getOperations(), - oldBody->begin(), std::prev(oldBody->end())); - forOp.getInductionVar()->replaceAllUsesWith(iv); - forOp.erase(); - } -} - void LowerLinalgToLLVMPass::runOnModule() { auto module = getModule(); - for (auto f : module.getOps()) { + for (auto f : module.getOps()) lowerLinalgSubViewOps(f); - lowerLinalgForToCFG(f); - } // Convert to the LLVM IR dialect using the converter defined above. OwningRewritePatternList patterns; LinalgTypeConverter converter(&getContext()); populateAffineToStdConversionPatterns(patterns, &getContext()); + populateLoopToStdConversionPatterns(patterns, &getContext()); populateStdToLLVMConversionPatterns(converter, patterns); populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext()); diff --git a/mlir/lib/Linalg/Transforms/Tiling.cpp b/mlir/lib/Linalg/Transforms/Tiling.cpp index f5d411f9ed5e..b66733356790 100644 --- a/mlir/lib/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Linalg/Transforms/Tiling.cpp @@ -19,6 +19,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/LoopOps/LoopOps.h" #include "mlir/EDSC/Helpers.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineExprVisitor.h" @@ -41,6 +42,7 @@ using namespace mlir::edsc; using namespace mlir::edsc::intrinsics; using namespace mlir::linalg; using namespace mlir::linalg::intrinsics; +using namespace mlir::loop; #define DEBUG_TYPE "linalg-tiling" @@ -444,7 +446,7 @@ mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef tileSizes, SmallVector loops; loops.reserve(ivs.size()); for (auto iv : ivs) - loops.push_back(linalg::getForInductionVarOwner(iv)); + loops.push_back(loop::getForInductionVarOwner(iv)); return TiledLinalgOp{res, loops}; } diff --git a/mlir/lib/Linalg/Utils/Utils.cpp b/mlir/lib/Linalg/Utils/Utils.cpp index 85b64db6fdc8..3bbfc9b175a4 100644 --- a/mlir/lib/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Linalg/Utils/Utils.cpp @@ -20,6 +20,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Linalg/Utils/Utils.h" +#include "mlir/Dialect/LoopOps/LoopOps.h" #include "mlir/EDSC/Helpers.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -29,6 +30,7 @@ #include "mlir/Linalg/Passes.h" #include "mlir/Linalg/Utils/Intrinsics.h" #include "mlir/Pass/Pass.h" +#include "mlir/StandardOps/Ops.h" #include "mlir/Support/STLExtras.h" #include "mlir/Transforms/FoldUtils.h" @@ -37,6 +39,7 @@ using namespace mlir::edsc; using namespace mlir::edsc::intrinsics; using namespace mlir::linalg; using namespace mlir::linalg::intrinsics; +using namespace mlir::loop; mlir::edsc::LoopRangeBuilder::LoopRangeBuilder(ValueHandle *iv, ValueHandle range) { @@ -47,18 +50,18 @@ mlir::edsc::LoopRangeBuilder::LoopRangeBuilder(ValueHandle *iv, auto lb = rangeOp.min(); auto ub = rangeOp.max(); auto step = rangeOp.step(); - auto forOp = OperationHandle::createOp(lb, ub, step); + auto forOp = OperationHandle::createOp(lb, ub, step); *iv = ValueHandle(forOp.getInductionVar()); - auto *body = forOp.getBody(); + auto *body = forOp.body(); enter(body, /*prev=*/1); } mlir::edsc::LoopRangeBuilder::LoopRangeBuilder(ValueHandle *iv, SubViewOp::Range range) { - auto forOp = OperationHandle::createOp(range.min, range.max, - range.step); + auto forOp = + OperationHandle::createOp(range.min, range.max, range.step); *iv = ValueHandle(forOp.getInductionVar()); - auto *body = forOp.getBody(); + auto *body = forOp.body(); enter(body, /*prev=*/1); } diff --git a/mlir/test/Conversion/LoopsToGPU/linalg_to_gpu.mlir b/mlir/test/Conversion/LoopsToGPU/linalg_to_gpu.mlir index ca7871ec6bea..3bec95506293 100644 --- a/mlir/test/Conversion/LoopsToGPU/linalg_to_gpu.mlir +++ b/mlir/test/Conversion/LoopsToGPU/linalg_to_gpu.mlir @@ -7,10 +7,10 @@ func @foo(%arg0: !linalg.buffer, %arg1 : index) { %c3 = constant 3 : index // CHECK: subi %{{.*}}, %{{.*}} : index // CHECK-NEXT: %[[range_i:.*]] = divis {{.*}}, %{{.*}} : index - linalg.for %i0 = %c0 to %c42 step %c3 { + loop.for %i0 = %c0 to %c42 step %c3 { // CHECK: subi %{{.*}}, %{{.*}} : index // CHECK-NEXT: %[[range_j:.*]] = divis {{.*}}, %{{.*}} : index - linalg.for %i1 = %c3 to %c42 step %arg1 { + loop.for %i1 = %c3 to %c42 step %arg1 { // CHECK: gpu.launch // CHECK-SAME: blocks // CHECK-SAME: threads diff --git a/mlir/test/Linalg/fusion.mlir b/mlir/test/Linalg/fusion.mlir index 73258aa72c84..24e078dcffc4 100644 --- a/mlir/test/Linalg/fusion.mlir +++ b/mlir/test/Linalg/fusion.mlir @@ -10,13 +10,13 @@ func @f1(%A: !linalg.view, %B: !linalg.view, %C: !linalg.view< } // No RAW dependences, the pass does not fuse RAR atm. // FUSE-0-LABEL: func @f1 -// FUSE-0-NOT: linalg.for +// FUSE-0-NOT: loop.for // FUSE-2-LABEL: func @f1 -// FUSE-2-NOT: linalg.for +// FUSE-2-NOT: loop.for // FUSE-23-LABEL: func @f1 -// FUSE-23-NOT: linalg.for +// FUSE-23-NOT: loop.for // FUSE-234-LABEL: func @f1 -// FUSE-234-NOT: linalg.for +// FUSE-234-NOT: loop.for func @f2(%A: !linalg.view, %B: !linalg.view, %C: !linalg.view, %D: !linalg.view, %E: !linalg.view) -> !linalg.view { linalg.matmul(%A, %B, %C) : !linalg.view, !linalg.view, !linalg.view @@ -25,19 +25,19 @@ func @f2(%A: !linalg.view, %B: !linalg.view, %C: !linalg.view< } // No tiling => no fusion // FUSE-0-LABEL: func @f2 -// FUSE-0-NOT: linalg.for +// FUSE-0-NOT: loop.for // // FUSE-2-LABEL: func @f2 // FUSE-2: %[[C_0:.*]] = linalg.dim %{{.*}}, 0 : !linalg.view -// FUSE-2: linalg.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} { +// FUSE-2: loop.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} { // FUSE-2: linalg.matmul // FUSE-2: linalg.matmul // // FUSE-23-LABEL: func @f2 // FUSE-23: %[[C_0:.*]] = linalg.dim %arg2, 0 : !linalg.view // FUSE-23: %[[D_1:.*]] = linalg.dim %arg3, 1 : !linalg.view -// FUSE-23: linalg.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} { -// FUSE-23: linalg.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} { +// FUSE-23: loop.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} { +// FUSE-23: loop.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} { // FUSE-23: linalg.matmul // FUSE-23: linalg.matmul // @@ -45,9 +45,9 @@ func @f2(%A: !linalg.view, %B: !linalg.view, %C: !linalg.view< // FUSE-234: %[[C_0:.*]] = linalg.dim %arg2, 0 : !linalg.view // FUSE-234: %[[C_1:.*]] = linalg.dim %arg2, 1 : !linalg.view // FUSE-234: %[[D_1:.*]] = linalg.dim %arg3, 1 : !linalg.view -// FUSE-234: linalg.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} { -// FUSE-234: linalg.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} { -// FUSE-234: linalg.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} { +// FUSE-234: loop.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} { +// FUSE-234: loop.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} { +// FUSE-234: loop.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} { // FUSE-234: linalg.matmul // FUSE-234: linalg.matmul @@ -58,17 +58,17 @@ func @f3(%A: !linalg.view, %B: !linalg.view, %C: !linalg.view< } // No tiling => no fusion // FUSE-0-LABEL: func @f3 -// FUSE-0-NOT: linalg.for +// FUSE-0-NOT: loop.for // // Read to %C does not get tiled along 1st dimension => no fusion // FUSE-2-LABEL: func @f3 -// FUSE-2-NOT: linalg.for +// FUSE-2-NOT: loop.for // // FUSE-23-LABEL: func @f3 // FUSE-23: %[[D_0:.*]] = linalg.dim %arg3, 0 : !linalg.view // FUSE-23: %[[C_1:.*]] = linalg.dim %arg2, 1 : !linalg.view -// FUSE-23: linalg.for %{{.*}} = %{{.*}} to %[[D_0]] step %{{.*}} { -// FUSE-23: linalg.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} { +// FUSE-23: loop.for %{{.*}} = %{{.*}} to %[[D_0]] step %{{.*}} { +// FUSE-23: loop.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} { // FUSE-23: linalg.matmul // FUSE-23: linalg.matmul // @@ -76,9 +76,9 @@ func @f3(%A: !linalg.view, %B: !linalg.view, %C: !linalg.view< // FUSE-234: %[[D_0:.*]] = linalg.dim %arg3, 0 : !linalg.view // FUSE-234: %[[D_1:.*]] = linalg.dim %arg3, 1 : !linalg.view // FUSE-234: %[[C_1:.*]] = linalg.dim %arg2, 1 : !linalg.view -// FUSE-234: linalg.for %{{.*}} = %{{.*}} to %[[D_0]] step %{{.*}} { -// FUSE-234: linalg.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} { -// FUSE-234: linalg.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} { +// FUSE-234: loop.for %{{.*}} = %{{.*}} to %[[D_0]] step %{{.*}} { +// FUSE-234: loop.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} { +// FUSE-234: loop.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} { // FUSE-234: linalg.matmul // FUSE-234: linalg.matmul @@ -90,21 +90,21 @@ func @f4(%A: !linalg.view, %B: !linalg.view, %C: !linalg.view< } // No tiling => no fusion // FUSE-0-LABEL: func @f4 -// FUSE-0-NOT: linalg.for +// FUSE-0-NOT: loop.for // // Read to %D does not get tiled along 1st dimension => no fusion // FUSE-2-LABEL: func @f4 // FUSE-2: linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}}) // FUSE-2: %[[C_0:.*]] = linalg.dim %{{.*}}, 0 : !linalg.view -// FUSE-2: linalg.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} { +// FUSE-2: loop.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} { // FUSE-2: linalg.matmul // FUSE-2: linalg.matmul // // FUSE-23-LABEL: func @f4 // FUSE-23: %[[C_0:.*]] = linalg.dim %arg2, 0 : !linalg.view // FUSE-23: %[[D_1:.*]] = linalg.dim %arg3, 1 : !linalg.view -// FUSE-23: linalg.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} { -// FUSE-23: linalg.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} { +// FUSE-23: loop.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} { +// FUSE-23: loop.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} { // FUSE-23: linalg.matmul // FUSE-23: linalg.matmul // FUSE-23: linalg.matmul @@ -113,9 +113,9 @@ func @f4(%A: !linalg.view, %B: !linalg.view, %C: !linalg.view< // FUSE-234: %[[C_0:.*]] = linalg.dim %arg2, 0 : !linalg.view // FUSE-234: %[[C_1:.*]] = linalg.dim %arg2, 1 : !linalg.view // FUSE-234: %[[D_1:.*]] = linalg.dim %arg3, 1 : !linalg.view -// FUSE-234: linalg.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} { -// FUSE-234: linalg.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} { -// FUSE-234: linalg.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} { +// FUSE-234: loop.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} { +// FUSE-234: loop.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} { +// FUSE-234: loop.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} { // FUSE-234: linalg.matmul // FUSE-234: linalg.matmul // FUSE-234: linalg.matmul @@ -128,12 +128,12 @@ func @f5(%A: !linalg.view, %B: !linalg.view, %C: !linalg.view< } // No tiling => no fusion // FUSE-0-LABEL: func @f5 -// FUSE-0-NOT: linalg.for +// FUSE-0-NOT: loop.for // // FUSE-2-LABEL: func @f5 // FUSE-2: linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}}) // FUSE-2: %[[D_0:.*]] = linalg.dim %{{.*}}, 0 : !linalg.view -// FUSE-2: linalg.for %{{.*}} = %{{.*}} to %[[D_0]] step %{{.*}} { +// FUSE-2: loop.for %{{.*}} = %{{.*}} to %[[D_0]] step %{{.*}} { // FUSE-2: linalg.matmul // FUSE-2: linalg.matmul // @@ -141,8 +141,8 @@ func @f5(%A: !linalg.view, %B: !linalg.view, %C: !linalg.view< // FUSE-23: linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}}) // FUSE-23: %[[D_0:.*]] = linalg.dim %arg3, 0 : !linalg.view // FUSE-23: %[[B_1:.*]] = linalg.dim %arg1, 1 : !linalg.view -// FUSE-23: linalg.for %{{.*}} = %{{.*}} to %[[D_0]] step %{{.*}} { -// FUSE-23: linalg.for %{{.*}} = %{{.*}} to %[[B_1]] step %{{.*}} { +// FUSE-23: loop.for %{{.*}} = %{{.*}} to %[[D_0]] step %{{.*}} { +// FUSE-23: loop.for %{{.*}} = %{{.*}} to %[[B_1]] step %{{.*}} { // FUSE-23: linalg.matmul // FUSE-23: linalg.matmul // @@ -151,9 +151,9 @@ func @f5(%A: !linalg.view, %B: !linalg.view, %C: !linalg.view< // FUSE-234: %[[D_0:.*]] = linalg.dim %arg3, 0 : !linalg.view // FUSE-234: %[[D_1:.*]] = linalg.dim %arg3, 1 : !linalg.view // FUSE-234: %[[B_1:.*]] = linalg.dim %arg1, 1 : !linalg.view -// FUSE-234: linalg.for %{{.*}} = %{{.*}} to %[[D_0]] step %{{.*}} { -// FUSE-234: linalg.for %{{.*}} = %{{.*}} to %[[B_1]] step %{{.*}} { -// FUSE-234: linalg.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} { +// FUSE-234: loop.for %{{.*}} = %{{.*}} to %[[D_0]] step %{{.*}} { +// FUSE-234: loop.for %{{.*}} = %{{.*}} to %[[B_1]] step %{{.*}} { +// FUSE-234: loop.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} { // FUSE-234: linalg.matmul // FUSE-234: linalg.matmul @@ -168,11 +168,11 @@ func @f6(%A: !linalg.view, %B: !linalg.view, %C: !linalg.view< // interleaved dependence. // No tiling => no fusion // FUSE-0-LABEL: func @f6 -// FUSE-0-NOT: linalg.for +// FUSE-0-NOT: loop.for // // Read to D is not tiled along 1st dimension => no fusion // FUSE-2-LABEL: func @f6 -// FUSE-2-NOT: linalg.for +// FUSE-2-NOT: loop.for // // FUSE-23-LABEL: func @f6 // @@ -189,18 +189,18 @@ func @f7(%A: !linalg.view, %B: !linalg.view, %C: !linalg.view< // immediately following read. // No tiling => no fusion // FUSE-0-LABEL: func @f7 -// FUSE-0-NOT: linalg.for +// FUSE-0-NOT: loop.for // // Read to %C (in 3rd matmul) is not tiled along 1st dimension => no fusion // FUSE-2-LABEL: func @f7 -// FUSE-2-NOT: linalg.for +// FUSE-2-NOT: loop.for // // FUSE-23-LABEL: func @f7 // FUSE-23: linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}}) // FUSE-23: %[[A_0:.*]] = linalg.dim %arg0, 0 : !linalg.view // FUSE-23: %[[C_1:.*]] = linalg.dim %arg2, 1 : !linalg.view -// FUSE-23: linalg.for %{{.*}} = %{{.*}} to %[[A_0]] step %{{.*}} { -// FUSE-23: linalg.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} { +// FUSE-23: loop.for %{{.*}} = %{{.*}} to %[[A_0]] step %{{.*}} { +// FUSE-23: loop.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} { // FUSE-23: linalg.matmul // FUSE-23: linalg.matmul // FUSE-23: linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}}) @@ -210,9 +210,9 @@ func @f7(%A: !linalg.view, %B: !linalg.view, %C: !linalg.view< // FUSE-234: %[[A_0:.*]] = linalg.dim %arg0, 0 : !linalg.view // FUSE-234: %[[A_1:.*]] = linalg.dim %arg0, 1 : !linalg.view // FUSE-234: %[[C_1:.*]] = linalg.dim %arg2, 1 : !linalg.view -// FUSE-234: linalg.for %{{.*}} = %{{.*}} to %[[A_0]] step %{{.*}} { -// FUSE-234: linalg.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} { -// FUSE-234: linalg.for %{{.*}} = %{{.*}} to %[[A_1]] step %{{.*}} { +// FUSE-234: loop.for %{{.*}} = %{{.*}} to %[[A_0]] step %{{.*}} { +// FUSE-234: loop.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} { +// FUSE-234: loop.for %{{.*}} = %{{.*}} to %[[A_1]] step %{{.*}} { // FUSE-234: linalg.matmul // FUSE-234: linalg.matmul // FUSE-234: linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}}) @@ -226,13 +226,13 @@ func @f8(%A: !linalg.view, %B: !linalg.view, %C: !linalg.view< // In this example, %D can never be fused because the WAR on %C would be violated // No tiling => no fusion // FUSE-0-LABEL: func @f8 -// FUSE-0-NOT: linalg.for +// FUSE-0-NOT: loop.for // // FUSE-2-LABEL: func @f8 -// FUSE-2-NOT: linalg.for +// FUSE-2-NOT: loop.for // // FUSE-23-LABEL: func @f8 -// FUSE-23-NOT: linalg.for +// FUSE-23-NOT: loop.for // // FUSE-234-LABEL: func @f8 -// FUSE-234-NOT: linalg.for +// FUSE-234-NOT: loop.for diff --git a/mlir/test/Linalg/llvm.mlir b/mlir/test/Linalg/llvm.mlir index e10fe16ff0b2..8dd632d1c311 100644 --- a/mlir/test/Linalg/llvm.mlir +++ b/mlir/test/Linalg/llvm.mlir @@ -118,64 +118,6 @@ func @range_intersect(%arg0: !linalg.range, %arg1: !linalg.range) -> !linalg.ran // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ i64, i64, i64 }"> // CHECK: llvm.return %{{.*}} : !llvm<"{ i64, i64, i64 }"> -func @linalg_for(%arg0 : index, %arg1 : index, %arg2 : index) { - linalg.for %i0 = %arg0 to %arg1 step %arg2 { - %a = muli %i0, %arg0 : index - } - return -} -// CHECK-LABEL: func @linalg_for(%{{.*}}: !llvm.i64, %{{.*}}: !llvm.i64, %{{.*}}: !llvm.i64) { -// CHECK: llvm.br ^bb2(%{{.*}} : !llvm.i64) -// CHECK: ^bb1: // pred: ^bb2 -// CHECK: llvm.return -// CHECK: ^bb2(%{{.*}}: !llvm.i64): // 2 preds: ^bb0, ^bb3 -// CHECK: %{{.*}} = llvm.icmp "sgt" %{{.*}}, %{{.*}} : !llvm.i64 -// CHECK: llvm.cond_br %{{.*}}, ^bb3, ^bb1 -// CHECK: ^bb3: // pred: ^bb2 -// CHECK: %{{.*}} = llvm.mul %{{.*}}, %{{.*}} : !llvm.i64 -// CHECK: %{{.*}} = llvm.add %{{.*}}, %{{.*}} : !llvm.i64 -// CHECK: llvm.br ^bb2(%{{.*}} : !llvm.i64) - -func @linalg_for_2(%arg0 : index, %arg1 : index, %arg2 : index) { - linalg.for %i0 = %arg0 to %arg1 step %arg2 { - linalg.for %i1 = %arg0 to %arg1 step %arg2 { - %a = muli %i0, %i1 : index - } - linalg.for %i2 = %arg0 to %arg1 step %arg2 { - %b = muli %i0, %i2 : index - } - } - return -} -// CHECK-LABEL: func @linalg_for_2(%{{.*}}: !llvm.i64, %{{.*}}: !llvm.i64, %{{.*}}: !llvm.i64) { -// CHECK: llvm.br ^bb2(%{{.*}} : !llvm.i64) -// CHECK: ^bb1: // pred: ^bb2 -// CHECK: llvm.return -// CHECK: ^bb2(%{{.*}}: !llvm.i64): // 2 preds: ^bb0, ^bb5 -// CHECK: %{{.*}} = llvm.icmp "sgt" %{{.*}}, %{{.*}} : !llvm.i64 -// CHECK: llvm.cond_br %{{.*}}, ^bb3, ^bb1 -// CHECK: ^bb3: // pred: ^bb2 -// CHECK: llvm.br ^bb8(%{{.*}} : !llvm.i64) -// CHECK: ^bb4: // pred: ^bb8 -// CHECK: llvm.br ^bb6(%{{.*}} : !llvm.i64) -// CHECK: ^bb5: // pred: ^bb6 -// CHECK: %{{.*}} = llvm.add %{{.*}}, %{{.*}} : !llvm.i64 -// CHECK: llvm.br ^bb2(%{{.*}} : !llvm.i64) -// CHECK: ^bb6(%{{.*}}: !llvm.i64): // 2 preds: ^bb4, ^bb7 -// CHECK: %{{.*}} = llvm.icmp "sgt" %{{.*}}, %{{.*}} : !llvm.i64 -// CHECK: llvm.cond_br %{{.*}}, ^bb7, ^bb5 -// CHECK: ^bb7: // pred: ^bb6 -// CHECK: %{{.*}} = llvm.mul %{{.*}}, %{{.*}} : !llvm.i64 -// CHECK: %{{.*}} = llvm.add %{{.*}}, %{{.*}} : !llvm.i64 -// CHECK: llvm.br ^bb6(%{{.*}} : !llvm.i64) -// CHECK: ^bb8(%{{.*}}: !llvm.i64): // 2 preds: ^bb3, ^bb9 -// CHECK: %{{.*}} = llvm.icmp "sgt" %{{.*}}, %{{.*}} : !llvm.i64 -// CHECK: llvm.cond_br %{{.*}}, ^bb9, ^bb4 -// CHECK: ^bb9: // pred: ^bb8 -// CHECK: %{{.*}} = llvm.mul %{{.*}}, %{{.*}} : !llvm.i64 -// CHECK: %{{.*}} = llvm.add %{{.*}}, %{{.*}} : !llvm.i64 -// CHECK: llvm.br ^bb8(%{{.*}} : !llvm.i64) - func @subview(%arg0: !linalg.view) { %c0 = constant 0 : index %0 = linalg.subview %arg0[%c0, %c0, %c0, %c0, %c0, %c0] : !linalg.view diff --git a/mlir/test/Linalg/loops.mlir b/mlir/test/Linalg/loops.mlir index c995c5eb5f04..a12aa9917595 100644 --- a/mlir/test/Linalg/loops.mlir +++ b/mlir/test/Linalg/loops.mlir @@ -23,9 +23,9 @@ func @matmul(%arg0: !linalg.buffer, %arg1: index, %arg2: index, %arg3: in // CHECK: %[[M:.*]] = linalg.dim %[[A]], 0 : !linalg.view // CHECK: %[[K:.*]] = linalg.dim %[[A]], 1 : !linalg.view // CHECK: %[[N:.*]] = linalg.dim %[[B]], 1 : !linalg.view -// CHECK: linalg.for %{{.*}} = %{{.*}} to %[[M]] step %{{.*}} { -// CHECK: linalg.for %{{.*}} = %{{.*}} to %[[N]] step %{{.*}} { -// CHECK: linalg.for %{{.*}} = %{{.*}} to %[[K]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[M]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[N]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[K]] step %{{.*}} { // CHECK-DAG: %[[a:.*]] = linalg.load %[[A]][%{{.*}}, %{{.*}}] : !linalg.view // CHECK-DAG: %[[b:.*]] = linalg.load %[[B]][%{{.*}}, %{{.*}}] : !linalg.view // CHECK-DAG: %[[inc:.*]] = mulf %[[a]], %[[b]] : f32 @@ -50,8 +50,8 @@ func @matvec(%arg0: !linalg.buffer, %arg1: index, %arg2: index, %arg3: in // CHECK: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer -> !linalg.view // CHECK: %[[M:.*]] = linalg.dim %[[A]], 0 : !linalg.view // CHECK: %[[K:.*]] = linalg.dim %[[A]], 1 : !linalg.view -// CHECK: linalg.for %{{.*}} = %{{.*}} to %[[M]] step %{{.*}} { -// CHECK: linalg.for %{{.*}} = %{{.*}} to %[[K]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[M]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[K]] step %{{.*}} { // CHECK-DAG: %[[a:.*]] = linalg.load %[[A]][%{{.*}}, %{{.*}}] : !linalg.view // CHECK-DAG: %[[b:.*]] = linalg.load %[[B]][%{{.*}}] : !linalg.view // CHECK-DAG: %[[inc:.*]] = mulf %[[a]], %[[b]] : f32 @@ -74,7 +74,7 @@ func @dot(%arg0: !linalg.buffer, %arg1: index, %arg2: index, %arg3: index // CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer -> !linalg.view // CHECK: %[[C:.*]] = linalg.view %arg0[] : !linalg.buffer -> !linalg.view // CHECK: %[[K:.*]] = linalg.dim %[[A]], 0 : !linalg.view -// CHECK: linalg.for %{{.*}} = %{{.*}} to %[[K]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[K]] step %{{.*}} { // CHECK-DAG: %[[a:.*]] = linalg.load %[[A]][%{{.*}}] : !linalg.view // CHECK-DAG: %[[b:.*]] = linalg.load %[[B]][%{{.*}}] : !linalg.view // CHECK-DAG: %[[inc:.*]] = mulf %[[a]], %[[b]] : f32 @@ -88,7 +88,7 @@ func @dot_view(%arg0: !linalg.view, %arg1: !linalg.view, %arg2: !l } // CHECK-LABEL: func @dot_view(%{{.*}}: !linalg.view, %{{.*}}: !linalg.view, %{{.*}}: !linalg.view) { // CHECK: %[[K:.*]] = linalg.dim %arg0, 0 : !linalg.view -// CHECK: linalg.for %{{.*}} = %{{.*}} to %[[K]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[K]] step %{{.*}} { // CHECK-DAG: %[[a:.*]] = linalg.load %arg0[%{{.*}}] : !linalg.view // CHECK-DAG: %[[b:.*]] = linalg.load %{{.*}}[%{{.*}}] : !linalg.view // CHECK-DAG: %[[inc:.*]] = mulf %[[a]], %[[b]] : f32 @@ -101,7 +101,7 @@ func @fill_view(%arg0: !linalg.view, %arg1: f32) { return } // CHECK-LABEL: func @fill_view(%{{.*}}: !linalg.view, %{{.*}}: f32) { -// CHECK: linalg.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { // CHECK: linalg.store %{{.*}}, %{{.*}}[%{{.*}}] : !linalg.view func @fill_view0(%arg0: !linalg.view, %arg1: f32) { @@ -116,9 +116,9 @@ func @fill_view3(%arg0: !linalg.view, %arg1: f32) { return } // CHECK-LABEL: func @fill_view3(%{{.*}}: !linalg.view, %{{.*}}: f32) { -// CHECK: linalg.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { -// CHECK: linalg.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { -// CHECK: linalg.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { // CHECK: linalg.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] : !linalg.view func @copy_view(%arg0: !linalg.view, %arg1: !linalg.view) { @@ -126,7 +126,7 @@ func @copy_view(%arg0: !linalg.view, %arg1: !linalg.view) { return } // CHECK-LABEL: func @copy_view(%{{.*}}: !linalg.view, %{{.*}}: !linalg.view) { -// CHECK: linalg.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { // CHECK: %[[L:.*]] = linalg.load %{{.*}}[%{{.*}}] : !linalg.view // CHECK: linalg.store %[[L]], %{{.*}}[%{{.*}}] : !linalg.view @@ -145,9 +145,9 @@ func @copy_view3(%arg0: !linalg.view, %arg1: !linalg.view) return } // CHECK-LABEL: func @copy_view3(%{{.*}}: !linalg.view, %{{.*}}: !linalg.view) { -// CHECK: linalg.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { -// CHECK: linalg.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { -// CHECK: linalg.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { // CHECK: %[[L:.*]] = linalg.load %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] : !linalg.view // CHECK: linalg.store %[[L]], %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] : !linalg.view @@ -161,11 +161,11 @@ func @conv_view3(%arg0: !linalg.view, %arg1: !linalg.view, // CHECK: %[[K:.*]] = linalg.dim %arg0, 2 : !linalg.view // CHECK: %[[B:.*]] = linalg.dim %arg1, 0 : !linalg.view // CHECK: %[[X0:.*]] = linalg.dim %arg2, 1 : !linalg.view -// CHECK: linalg.for %{{.*}} = %{{.*}} to %[[B]] step %{{.*}} { -// CHECK: linalg.for %{{.*}} = %{{.*}} to %[[X0]] step %{{.*}} { -// CHECK: linalg.for %{{.*}} = %{{.*}} to %[[K]] step %{{.*}} { -// CHECK: linalg.for %{{.*}} = %{{.*}} to %[[Q]] step %{{.*}} { -// CHECK: linalg.for %{{.*}} = %{{.*}} to %[[Z0]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[B]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[X0]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[K]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[Q]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[Z0]] step %{{.*}} { // CHECK: %[[SUM:.*]] = affine.apply #[[S2D1]](%{{.*}}, %{{.*}}) // CHECK: %{{.*}} = linalg.load %{{.*}}[%{{.*}}, %[[SUM]], %{{.*}}] : !linalg.view // CHECK: %{{.*}} = linalg.load %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] : !linalg.view @@ -186,13 +186,13 @@ func @conv_view4(%arg0: !linalg.view, %arg1: !linalg.view // CHECK: %[[X0:.*]] = linalg.dim %arg2, 1 : !linalg.view // CHECK: %[[X1:.*]] = linalg.dim %arg2, 2 : !linalg.view -// CHECK: linalg.for %{{.*}} = %{{.*}} to %[[B]] step %{{.*}} { -// CHECK: linalg.for %{{.*}} = %{{.*}} to %[[X0]] step %{{.*}} { -// CHECK: linalg.for %{{.*}} = %{{.*}} to %[[X1]] step %{{.*}} { -// CHECK: linalg.for %{{.*}} = %{{.*}} to %[[K]] step %{{.*}} { -// CHECK: linalg.for %{{.*}} = %{{.*}} to %[[Q]] step %{{.*}} { -// CHECK: linalg.for %{{.*}} = %{{.*}} to %[[Z0]] step %{{.*}} { -// CHECK: linalg.for %{{.*}} = %{{.*}} to %[[Z1]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[B]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[X0]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[X1]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[K]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[Q]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[Z0]] step %{{.*}} { +// CHECK: loop.for %{{.*}} = %{{.*}} to %[[Z1]] step %{{.*}} { // CHECK: %[[SUM0:.*]] = affine.apply #map1(%{{.*}}, %{{.*}}) // CHECK: %[[SUM1:.*]] = affine.apply #map2(%{{.*}}, %{{.*}}) // CHECK: %{{.*}} = linalg.load %{{.*}}[%{{.*}}, %[[SUM0]], %[[SUM1]], %{{.*}}] : !linalg.view diff --git a/mlir/test/Linalg/promote.mlir b/mlir/test/Linalg/promote.mlir index 23d51ec770e8..611f1aa126fb 100644 --- a/mlir/test/Linalg/promote.mlir +++ b/mlir/test/Linalg/promote.mlir @@ -13,9 +13,9 @@ func @matmul(%arg0: !linalg.buffer, %arg1: index, %arg2: index, %arg3: in return } // TILE-1D-LABEL: func @matmul(%{{.*}}: !linalg.buffer, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) { -// TILE-1D: linalg.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { -// TILE-1D: linalg.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { -// TILE-1D: linalg.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { +// TILE-1D: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { +// TILE-1D: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { +// TILE-1D: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { // TILE-1D: %[[vA:.*]] = linalg.subview {{.*}} : !linalg.view // TILE-1D: %[[vB:.*]] = linalg.subview {{.*}} : !linalg.view // TILE-1D: %[[vC:.*]] = linalg.subview {{.*}} : !linalg.view diff --git a/mlir/test/Linalg/roundtrip.mlir b/mlir/test/Linalg/roundtrip.mlir index becdf53dc46e..2a3a3c5bd356 100644 --- a/mlir/test/Linalg/roundtrip.mlir +++ b/mlir/test/Linalg/roundtrip.mlir @@ -84,26 +84,26 @@ func @range_intersect(%arg0: !linalg.range, %arg1: !linalg.range) -> !linalg.ran // CHECK-NEXT: return %{{.*}} : !linalg.range func @linalg_for(%arg0 : index, %arg1 : index, %arg2 : index) { - linalg.for %i0 = %arg0 to %arg1 step %arg2 { - linalg.for %i1 = %arg0 to %arg1 step %arg2 { + loop.for %i0 = %arg0 to %arg1 step %arg2 { + loop.for %i1 = %arg0 to %arg1 step %arg2 { %min_cmp = cmpi "slt", %i0, %i1 : index %min = select %min_cmp, %i0, %i1 : index %max_cmp = cmpi "sge", %i0, %i1 : index %max = select %max_cmp, %i0, %i1 : index - linalg.for %i2 = %min to %max step %i1 { + loop.for %i2 = %min to %max step %i1 { } } } return } // CHECK-LABEL: func @linalg_for(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) { -// CHECK-NEXT: linalg.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { -// CHECK-NEXT: linalg.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { +// CHECK-NEXT: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { +// CHECK-NEXT: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { // CHECK-NEXT: %{{.*}} = cmpi "slt", %{{.*}}, %{{.*}} : index // CHECK-NEXT: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : index // CHECK-NEXT: %{{.*}} = cmpi "sge", %{{.*}}, %{{.*}} : index // CHECK-NEXT: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : index -// CHECK-NEXT: linalg.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { +// CHECK-NEXT: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { func @fill_view(%arg0: !linalg.view, %arg1: f32) { linalg.fill(%arg0, %arg1) : !linalg.view, f32 diff --git a/mlir/test/Linalg/tile.mlir b/mlir/test/Linalg/tile.mlir index b65ba390a119..92898b77be7f 100644 --- a/mlir/test/Linalg/tile.mlir +++ b/mlir/test/Linalg/tile.mlir @@ -16,7 +16,7 @@ func @matmul(%arg0: !linalg.view, %arg1: !linalg.view, %arg2: } // TILE-2-LABEL: func @matmul(%{{.*}}: !linalg.view, %{{.*}}: !linalg.view, %{{.*}}: !linalg.view) { // TILE-2: %[[M:.*]] = linalg.dim %{{.*}}, 0 : !linalg.view -// TILE-2: linalg.for %{{.*}} = %{{.*}}{{.*}} to %[[M]] step %{{.*}} { +// TILE-2: loop.for %{{.*}} = %{{.*}}{{.*}} to %[[M]] step %{{.*}} { // TILE-2: %[[a:.*]] = affine.apply #[[UB0]](%{{.*}}) // TILE-2: %[[K:.*]] = linalg.dim %{{.*}}, 1 : !linalg.view // TILE-2: %[[sAi:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[a]], %{{.*}}, %{{.*}}, %[[K]], %{{.*}}] : !linalg.view @@ -27,7 +27,7 @@ func @matmul(%arg0: !linalg.view, %arg1: !linalg.view, %arg2: // TILE-02-LABEL: func @matmul(%{{.*}}: !linalg.view, %{{.*}}: !linalg.view, %{{.*}}: !linalg.view) { // TILE-02: %[[N:.*]] = linalg.dim %arg1, 1 : !linalg.view -// TILE-02: linalg.for %{{.*}} = %{{.*}} to %[[N]] step %{{.*}} { +// TILE-02: loop.for %{{.*}} = %{{.*}} to %[[N]] step %{{.*}} { // TILE-02: %[[K:.*]] = linalg.dim %{{.*}}, 0 : !linalg.view // TILE-02: %[[b:.*]] = affine.apply #[[UB0]](%{{.*}}) // TILE-02: %[[sBj:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[K]], %{{.*}}, %{{.*}}, %[[b]], %{{.*}}] : !linalg.view @@ -38,7 +38,7 @@ func @matmul(%arg0: !linalg.view, %arg1: !linalg.view, %arg2: // TILE-002-LABEL: func @matmul(%{{.*}}: !linalg.view, %{{.*}}: !linalg.view, %{{.*}}: !linalg.view) { // TILE-002: %[[K:.*]] = linalg.dim %{{.*}}, 1 : !linalg.view -// TILE-002: linalg.for %{{.*}} = %{{.*}}{{.*}} to %[[K]] step %{{.*}} { +// TILE-002: loop.for %{{.*}} = %{{.*}}{{.*}} to %[[K]] step %{{.*}} { // TILE-002: %[[M:.*]] = linalg.dim %{{.*}}, 0 : !linalg.view // TILE-002: %[[a:.*]] = affine.apply #[[UB0]](%{{.*}}) // TILE-002: %[[sAj:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[M]], %{{.*}}, %{{.*}}, %[[a]], %{{.*}}] : !linalg.view @@ -51,9 +51,9 @@ func @matmul(%arg0: !linalg.view, %arg1: !linalg.view, %arg2: // TILE-234: %[[M:.*]] = linalg.dim %{{.*}}, 0 : !linalg.view // TILE-234: %[[K:.*]] = linalg.dim %{{.*}}, 1 : !linalg.view // TILE-234: %[[N:.*]] = linalg.dim %{{.*}}, 1 : !linalg.view -// TILE-234: linalg.for %{{.*}} = %{{.*}}{{.*}} to %[[M]] step %{{.*}} { -// TILE-234: linalg.for %{{.*}} = %{{.*}}{{.*}} to %[[N]] step %{{.*}} { -// TILE-234: linalg.for %{{.*}} = %{{.*}}{{.*}} to %[[K]] step %{{.*}} { +// TILE-234: loop.for %{{.*}} = %{{.*}}{{.*}} to %[[M]] step %{{.*}} { +// TILE-234: loop.for %{{.*}} = %{{.*}}{{.*}} to %[[N]] step %{{.*}} { +// TILE-234: loop.for %{{.*}} = %{{.*}}{{.*}} to %[[K]] step %{{.*}} { // TILE-234: %[[ai:.*]] = affine.apply #[[UB0]](%{{.*}}) // TILE-234: %[[ak:.*]] = affine.apply #[[UB2]](%{{.*}}) // TILE-234: %[[sAik:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[ai]], %{{.*}}, %{{.*}}, %[[ak]], %{{.*}}] : !linalg.view @@ -72,7 +72,7 @@ func @matvec(%arg0: !linalg.view, %arg1: !linalg.view, %arg2: !l } // TILE-2-LABEL: func @matvec(%{{.*}}: !linalg.view, %{{.*}}: !linalg.view, %{{.*}}: !linalg.view) { // TILE-2: %[[M:.*]] = linalg.dim %{{.*}}, 0 : !linalg.view -// TILE-2: linalg.for %{{.*}} = %{{.*}}{{.*}} to %[[M]] step %{{.*}} { +// TILE-2: loop.for %{{.*}} = %{{.*}}{{.*}} to %[[M]] step %{{.*}} { // TILE-2: %[[a:.*]] = affine.apply #[[UB0]](%{{.*}}) // TILE-2: %[[N:.*]] = linalg.dim %{{.*}}, 1 : !linalg.view // TILE-2: %[[sAi:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[a]], %{{.*}}, %{{.*}}, %[[N]], %{{.*}}] : !linalg.view @@ -82,7 +82,7 @@ func @matvec(%arg0: !linalg.view, %arg1: !linalg.view, %arg2: !l // TILE-02-LABEL: func @matvec(%{{.*}}: !linalg.view, %{{.*}}: !linalg.view, %{{.*}}: !linalg.view) { // TILE-02: %[[K:.*]] = linalg.dim %{{.*}}, 1 : !linalg.view -// TILE-02: linalg.for %{{.*}} = %{{.*}}{{.*}} to %[[K]] step %{{.*}} { +// TILE-02: loop.for %{{.*}} = %{{.*}}{{.*}} to %[[K]] step %{{.*}} { // TILE-02: %[[M:.*]] = linalg.dim %{{.*}}, 0 : !linalg.view // TILE-02: %[[a:.*]] = affine.apply #[[UB0]](%{{.*}}) // TILE-02: %[[sAj:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[M]], %{{.*}}, %{{.*}}, %[[a]], %{{.*}}] : !linalg.view @@ -91,13 +91,13 @@ func @matvec(%arg0: !linalg.view, %arg1: !linalg.view, %arg2: !l // TILE-02: linalg.matvec(%[[sAj]], %[[sBj]], %{{.*}}) : !linalg.view, !linalg.view, !linalg.view // TILE-002-LABEL: func @matvec(%{{.*}}: !linalg.view, %{{.*}}: !linalg.view, %{{.*}}: !linalg.view) { -// TILE-002-NOT: linalg.for +// TILE-002-NOT: loop.for // TILE-234-LABEL: func @matvec(%{{.*}}: !linalg.view, %{{.*}}: !linalg.view, %{{.*}}: !linalg.view) { // TILE-234: %[[M:.*]] = linalg.dim %{{.*}}, 0 : !linalg.view // TILE-234: %[[K:.*]] = linalg.dim %{{.*}}, 1 : !linalg.view -// TILE-234: linalg.for %{{.*}} = %{{.*}}{{.*}} to %[[M]] step %{{.*}} { -// TILE-234: linalg.for %{{.*}} = %{{.*}}{{.*}} to %[[K]] step %{{.*}} { +// TILE-234: loop.for %{{.*}} = %{{.*}}{{.*}} to %[[M]] step %{{.*}} { +// TILE-234: loop.for %{{.*}} = %{{.*}}{{.*}} to %[[K]] step %{{.*}} { // TILE-234: %[[ai:.*]] = affine.apply #[[UB0]](%{{.*}}) // TILE-234: %[[aj:.*]] = affine.apply #[[UB1]](%{{.*}}) // TILE-234: %[[sAij:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[ai]], %{{.*}}, %{{.*}}, %[[aj]], %{{.*}}] : !linalg.view @@ -114,7 +114,7 @@ func @dot(%arg0: !linalg.view, %arg1: !linalg.view, %arg2: !linalg } // TILE-2-LABEL: func @dot(%{{.*}}: !linalg.view, %{{.*}}: !linalg.view, %{{.*}}: !linalg.view) { // TILE-2: %[[M:.*]] = linalg.dim %{{.*}}, 0 : !linalg.view -// TILE-2: linalg.for %{{.*}} = %{{.*}}{{.*}} to %[[M]] step %{{.*}} { +// TILE-2: loop.for %{{.*}} = %{{.*}}{{.*}} to %[[M]] step %{{.*}} { // TILE-2: %[[a:.*]] = affine.apply #[[UB0]](%{{.*}}) // TILE-2: %[[sAi:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[a]], %{{.*}}] : !linalg.view // TILE-2: %[[b:.*]] = affine.apply #[[UB0]](%{{.*}}) @@ -122,14 +122,14 @@ func @dot(%arg0: !linalg.view, %arg1: !linalg.view, %arg2: !linalg // TILE-2: linalg.dot(%[[sAi]], %[[sBi]], {{.*}}) : !linalg.view, !linalg.view, !linalg.view // TILE-02-LABEL: func @dot(%{{.*}}: !linalg.view, %{{.*}}: !linalg.view, %{{.*}}: !linalg.view) { -// TILE-02-NOT: linalg.for +// TILE-02-NOT: loop.for // TILE-002-LABEL: func @dot(%{{.*}}: !linalg.view, %{{.*}}: !linalg.view, %{{.*}}: !linalg.view) { -// TILE-002-NOT: linalg.for +// TILE-002-NOT: loop.for // TILE-234-LABEL: func @dot(%{{.*}}: !linalg.view, %{{.*}}: !linalg.view, %{{.*}}: !linalg.view) { // TILE-234: %[[K:.*]] = linalg.dim %{{.*}}, 0 : !linalg.view -// TILE-234: linalg.for %{{.*}} = %{{.*}}{{.*}} to %[[K]] step %{{.*}} { +// TILE-234: loop.for %{{.*}} = %{{.*}}{{.*}} to %[[K]] step %{{.*}} { // TILE-234: %[[a:.*]] = affine.apply #[[UB0]](%{{.*}}) // TILE-234: %[[sAi:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[a]], %{{.*}}] : !linalg.view // TILE-234: %[[b:.*]] = affine.apply #[[UB0]](%{{.*}}) diff --git a/mlir/test/Linalg/tile_conv.mlir b/mlir/test/Linalg/tile_conv.mlir index 6d85556a7d79..5fdb28df5d92 100644 --- a/mlir/test/Linalg/tile_conv.mlir +++ b/mlir/test/Linalg/tile_conv.mlir @@ -14,9 +14,9 @@ func @conv(%arg0: !linalg.view, %arg1: !linalg.view, % // TILE-23004: %[[B:.*]] = linalg.dim %{{.*}}, 0 : !linalg.view // TILE-23004: %[[PaddedInput0:.*]] = linalg.dim %{{.*}}, 1 : !linalg.view // TILE-23004: %[[X0:.*]] = linalg.dim %{{.*}}, 1 : !linalg.view -// TILE-23004: linalg.for %{{.*}} = %{{.*}} to %[[B]] step %{{.*}} { -// TILE-23004: linalg.for %{{.*}} = %{{.*}} to %[[X0]] step %{{.*}} { -// TILE-23004: linalg.for %{{.*}} = %{{.*}} to %[[Q]] step %{{.*}} { +// TILE-23004: loop.for %{{.*}} = %{{.*}} to %[[B]] step %{{.*}} { +// TILE-23004: loop.for %{{.*}} = %{{.*}} to %[[X0]] step %{{.*}} { +// TILE-23004: loop.for %{{.*}} = %{{.*}} to %[[Q]] step %{{.*}} { // TILE-23004: %[[Z0:.*]] = linalg.dim %{{.*}}, 0 : !linalg.view // TILE-23004: %[[Z1:.*]] = linalg.dim %{{.*}}, 1 : !linalg.view // TILE-23004: %[[I2p4:.*]] = affine.apply #[[UB2]](%{{.*}}) diff --git a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir index 575f81d5cf5d..2dc974883c6f 100644 --- a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir +++ b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir @@ -9,7 +9,7 @@ func @fill_f32(%arg0 : !linalg.buffer, %f : f32) { %s = linalg.buffer_size %arg0 : !linalg.buffer %R = linalg.range %c0:%s:%c1 : !linalg.range %V = linalg.view %arg0[%R] : !linalg.buffer -> !linalg.view - linalg.for %i0 = %c0 to %s step %c1 { + loop.for %i0 = %c0 to %s step %c1 { linalg.store %f, %V[%i0] : !linalg.view } return