diff --git a/mlir/tutorial/Linalg2/include/linalg2/TensorOps.h b/mlir/tutorial/Linalg2/include/linalg2/TensorOps.h index c20a91681654..406bcaacce25 100644 --- a/mlir/tutorial/Linalg2/include/linalg2/TensorOps.h +++ b/mlir/tutorial/Linalg2/include/linalg2/TensorOps.h @@ -41,6 +41,7 @@ protected: static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); void print(mlir::OpAsmPrinter *p); +public: ////////////////////////////////////////////////////////////////////////////// // Op-specific functionality. ////////////////////////////////////////////////////////////////////////////// @@ -48,7 +49,6 @@ protected: mlir::Operation::operand_range getInputs(); mlir::Operation::operand_range getOutputs(); -public: /// These are better as methods calling into the ConcreteOp instead of /// template parameters because methods allow more generic behavior and avoid /// specializing for number of arguments. All derived classes have @@ -72,14 +72,24 @@ public: ////////////////////////////////////////////////////////////////////////////// mlir::Value *getInputView(unsigned i); mlir::Value *getOutputView(unsigned i); - /// Computes a mapping from all the ranges of the operands to the enclosing - /// loops. In order to support "broadcast"-style semantics, we need to - /// consider all the operands (i.e. input operands are not sufficient). + + /// Each op is responsible for declaring how it lowers itself to scalar form, + /// given the enclosing parallel and reduction induction variables. + /// `emitScalarImplementation` emits the scalar IR for the op in the nesting + /// context of the innermost enclosing loop(i.e. `reductionIvs.back()` or + /// `parallel.back()`). + void emitScalarImplementation(llvm::ArrayRef parallelIvs, + llvm::ArrayRef reductionIvs); + + /// Represents a mapping from the loops to all the ranges of the operands. /// The operands and their ranges are in the order defined by the particular /// ConcreteOp implementation, the resulting map must match those. - /// This is currently computed but can also be specified explicitly in each - /// operator to generalize to cases where an analysis is not available. - mlir::AffineMap operandRangesToLoopsMap(); + /// In favorable cases, this can be calculated by an analysis but specifying + /// it explicitly is not expensive and generalizes to cases where an analysis + /// is not available. + /// For details, see the description of loopsToOperandRangesMap in each + /// ConcreteOp + mlir::AffineMap loopsToOperandRangesMap(); }; /// Implements c = A * B where c is a scalar and A and B are 1-D vectors. @@ -119,6 +129,27 @@ public: /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a /// loop over matvec). Does nothing by default. void writeAsFinerGrainTensorContraction(); + + /// Inputs to this map will be (%k) coming from enclosing loops. + /// Therefore, the mapping to get back to A(K), B(K), C() is: + /// (d0) -> (d0, d0)(%k) + /// And the operands ranges are: + /// (%k, %k) + mlir::AffineMap loopsToOperandRangesMap(); + + /// Given an enclosing reduction loop with iv `r_i`, emits MLIR corresponding + /// to: + /// 1. conditionally assign scalarC to 0.0f on the first iteration or load + /// C[] from memory (0-D tensor) + /// 2. multiply A[r_i] by B[r_i] and add to scalarC + /// 3. store back scalarC at C[] + /// + /// In some compact index notation this could be written: + /// cond = (r_i == zero) + /// scalarC = select(cond, zerof, C[]); + /// C[] = scalarC + A[r_i] * B[r_i]; + void emitScalarImplementation(llvm::ArrayRef parallelIvs, + llvm::ArrayRef reductionIvs); }; /// Implements C = A * B where A is a 2-D matrix and X and Y are 1-D vectors. @@ -158,6 +189,27 @@ public: /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a /// loop over matvec). Does nothing by default. void writeAsFinerGrainTensorContraction(); + + /// Inputs to this map will be (%m, %k) coming from enclosing loops. + /// Therefore, the mapping to get back to A(M, K), B(K), C(M) is: + /// (d0, d1) -> (d0, d1, d1, d0)(%m, %k) + /// And the operands ranges are: + /// (%m, %k, %k, %m) + mlir::AffineMap loopsToOperandRangesMap(); + + /// Given an enclosing parallel loop with iv `i` and an enclosing parallel + /// loop with iv `r_j`, emits MLIR corresponding to: + /// 1. conditionally assign scalarC to 0.0f on the first iteration or load + /// C[i] + /// 2. multiply A[i, r_j] by B[r_j] and add to scalarC + /// 3. store back scalarC at C[i] + /// + /// In some compact index notation this could be written: + /// cond = (r_j == zero) + /// scalarC = select(cond, zerof, C(i)); + /// C(i) = scalarC + A(i, r_j) * B(r_j); + void emitScalarImplementation(llvm::ArrayRef parallelIvs, + llvm::ArrayRef reductionIvs); }; /// Implements C = A * B on 2-D matrices. @@ -197,6 +249,27 @@ public: /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a /// loop over matvec). Does nothing by default. void writeAsFinerGrainTensorContraction(); + + /// Inputs to this map will be (%m, %n, %k) coming from enclosing loops. + /// Therefore, the mapping to get back to A(M, K), B(K, N), C(M, N) is: + /// (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1)(%m, %n, %k) + /// And the operands ranges are: + /// (%m, %k, %k, %n, %m, %n) + mlir::AffineMap loopsToOperandRangesMap(); + + /// Given a enclosing parallel loops with ivs `i` and `j`, and an enclosing + /// reduction loop with iv `r_k`, emits MLIR corresponding to: + /// 1. conditionally assign scalarC to 0.0f on the first iteration or load + /// C[i, j] + /// 2. multiply A[i, r_k] by B[r_k, j] and add to scalarC + /// 3. store back scalarC at C[i, j] + /// + /// In some compact index notation this could be written: + /// cond = (r_k == zero) + /// scalarC = select(cond, zerof, C[i, j]); + /// C[i, j] = scalarC + A[i, r_k] * B[r_k, j]; + void emitScalarImplementation(llvm::ArrayRef parallelIvs, + llvm::ArrayRef reductionIvs); }; } // namespace linalg diff --git a/mlir/tutorial/Linalg3/Example.cpp b/mlir/tutorial/Linalg3/Example.cpp index 1c10fd5d81bb..13eadd4fc8b5 100644 --- a/mlir/tutorial/Linalg3/Example.cpp +++ b/mlir/tutorial/Linalg3/Example.cpp @@ -64,13 +64,15 @@ TEST_FUNC(matmul_as_matvec) { Module module(&context); mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec"); lowerToFinerGrainedTensorContraction(f); + composeSliceOps(f); // clang-format off // CHECK-LABEL: func @matmul_as_matvec(%arg0: memref, %arg1: memref, %arg2: memref) { // CHECK: %[[N:.*]] = dim %arg2, 1 : memref + // CHECK: %[[vA:.*]] = linalg.view %arg0[%{{.*}}, %{{.*}}] : !linalg<"view"> // CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[N]]) { - // CHECK-NEXT: %[[vB:.*]] = linalg.slice %{{.*}}[*, %i0] { dim : 1 } : !linalg<"view"> - // CHECK-NEXT: %[[vC:.*]] = linalg.slice %{{.*}}[*, %i0] { dim : 1 } : !linalg<"view"> - // CHECK-NEXT: linalg.matvec {%{{.*}}, %[[vB]]} -> {%[[vC]]} + // CHECK: %[[vB:.*]] = linalg.view %arg1[%{{.*}}, %{{.*}}] : !linalg<"view"> + // CHECK: %[[vC:.*]] = linalg.view %arg2[%{{.*}}, %{{.*}}] : !linalg<"view"> + // CHECK: linalg.matvec {%[[vA]], %[[vB]]} -> {%[[vC]]} // clang-format on cleanupAndPrintFunction(f); } @@ -81,21 +83,84 @@ TEST_FUNC(matmul_as_dot) { mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_dot"); lowerToFinerGrainedTensorContraction(f); lowerToFinerGrainedTensorContraction(f); + composeSliceOps(f); // clang-format off // CHECK-LABEL: func @matmul_as_dot(%arg0: memref, %arg1: memref, %arg2: memref) { // CHECK: %[[M:.*]] = dim %arg0, 0 : memref // CHECK: %[[N:.*]] = dim %arg2, 1 : memref // CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[N]]) { - // CHECK-NEXT: %[[vB:.*]] = linalg.slice {{.*}}[*, %i0] { dim : 1 } : !linalg<"view"> - // CHECK-NEXT: %[[sC:.*]] = linalg.slice {{.*}}[*, %i0] { dim : 1 } : !linalg<"view"> + // CHECK: %[[vB:.*]] = linalg.view %arg1[%{{.*}}, %{{.*}}] : !linalg<"view"> // CHECK-NEXT: affine.for %i1 = 0 to (d0) -> (d0)(%[[M]]) { - // CHECK-NEXT: %[[vA:.*]] = linalg.slice {{.*}}[%i1, *] { dim : 0 } : !linalg<"view"> - // CHECK-NEXT: %[[vC:.*]] = linalg.slice %[[sC]][%i1] { dim : 0 } : !linalg<"view<0xf32>"> + // CHECK: %[[vA:.*]] = linalg.view %arg0[%{{.*}}, %{{.*}}] : !linalg<"view"> + // CHECK-NEXT: %[[vC:.*]] = linalg.view %arg2[%{{.*}}, %{{.*}}] : !linalg<"view<0xf32>"> // CHECK-NEXT: linalg.dot {%[[vA]], %[[vB]]} -> {%[[vC]]} // clang-format on cleanupAndPrintFunction(f); } +TEST_FUNC(matmul_as_loops) { + MLIRContext context; + Module module(&context); + mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_loops"); + lowerToLoops(f); + composeSliceOps(f); + // clang-format off + // CHECK-LABEL: func @matmul_as_loops(%arg0: memref, %arg1: memref, %arg2: memref) { + // CHECK: %[[M:.*]] = dim %arg0, 0 : memref + // CHECK: %[[N:.*]] = dim %arg2, 1 : memref + // CHECK: %[[K:.*]] = dim %arg0, 1 : memref + // CHECK: %[[rM:.*]] = linalg.range %c0:%[[M]]:%c1 : !linalg<"range"> + // CHECK: %[[rN:.*]] = linalg.range %c0:%[[N]]:%c1 : !linalg<"range"> + // CHECK: %[[rK:.*]] = linalg.range %c0:%[[K]]:%c1 : !linalg<"range"> + // CHECK: %[[vA:.*]] = linalg.view %arg0[%[[rM]], %[[rK]]] : !linalg<"view"> + // CHECK: %[[vB:.*]] = linalg.view %arg1[%[[rK]], %[[rN]]] : !linalg<"view"> + // CHECK: %[[vC:.*]] = linalg.view %arg2[%[[rM]], %[[rN]]] : !linalg<"view"> + // CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[M]]) { + // CHECK: affine.for %i1 = 0 to (d0) -> (d0)(%[[N]]) { + // CHECK: affine.for %i2 = 0 to (d0) -> (d0)(%[[K]]) { + // CHECK: %{{.*}} = cmpi "eq", %{{.*}} : index + // CHECK: %{{.*}} = linalg.load %[[vC]][%i0, %i1] : !linalg<"view"> + // CHECK: %{{.*}} = select {{.*}} : f32 + // CHECK: %{{.*}} = linalg.load %[[vB]][%i2, %i1] : !linalg<"view"> + // CHECK: %{{.*}} = linalg.load %[[vA]][%i0, %i2] : !linalg<"view"> + // CHECK: %{{.*}} = mulf {{.*}} : f32 + // CHECK: %{{.*}} = addf {{.*}} : f32 + // CHECK: linalg.store {{.*}}[%i0, %i1] : !linalg<"view"> + // clang-format on + cleanupAndPrintFunction(f); +} + +TEST_FUNC(matmul_as_matvec_as_loops) { + MLIRContext context; + Module module(&context); + mlir::Function *f = + makeFunctionWithAMatmulOp(module, "matmul_as_matvec_as_loops"); + lowerToFinerGrainedTensorContraction(f); + lowerToLoops(f); + composeSliceOps(f); + // clang-format off + // CHECK-LABEL: func @matmul_as_matvec_as_loops(%arg0: memref, %arg1: memref, %arg2: memref) { + // CHECK: %[[M:.*]] = dim %arg0, 0 : memref + // CHECK: %[[N:.*]] = dim %arg2, 1 : memref + // CHECK: %[[K:.*]] = dim %arg0, 1 : memref + // CHECK: %[[vA:.*]] = linalg.view %arg0[{{.*}}, {{.*}}] : !linalg<"view"> + // CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[N]]) { + // CHECK: %[[vB:.*]] = linalg.view %arg1[{{.*}}, {{.*}}] : !linalg<"view"> + // CHECK: %[[vC:.*]] = linalg.view %arg2[{{.*}}, {{.*}}] : !linalg<"view"> + // CHECK: affine.for %i1 = 0 to (d0) -> (d0)(%[[M]]) { + // CHECK: affine.for %i2 = 0 to (d0) -> (d0)(%[[K]]) { + // CHECK: %{{.*}} = cmpi "eq", %i2, %{{.*}} : index + // CHECK: %[[C:.*]] = linalg.load %[[vC]][%i1] : !linalg<"view"> + // CHECK: %[[C2:.*]] = select %{{.*}}, %{{.*}}, %[[C]] : f32 + // CHECK: %[[B:.*]] = linalg.load %[[vB]][%i2] : !linalg<"view"> + // CHECK: %[[A:.*]] = linalg.load %[[vA]][%i1, %i2] : !linalg<"view"> + // CHECK: %{{.*}} = mulf %[[A]], %[[B]] : f32 + // CHECK: %{{.*}} = addf %[[C2]], %{{.*}} : f32 + // CHECK: linalg.store %{{.*}}, %{{.*}}[%i1] : !linalg<"view"> + // clang-format on + cleanupAndPrintFunction(f); +} + int main() { RUN_TESTS(); return 0; diff --git a/mlir/tutorial/Linalg3/include/linalg3/Analysis.h b/mlir/tutorial/Linalg3/include/linalg3/Analysis.h new file mode 100644 index 000000000000..813fc37b73e0 --- /dev/null +++ b/mlir/tutorial/Linalg3/include/linalg3/Analysis.h @@ -0,0 +1,37 @@ +//===- Analysis.h - Linalg dialect Analysis function definitions ----------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG3_ANALYSIS_H_ +#define LINALG3_ANALYSIS_H_ + +#include "linalg2/Analysis.h" + +namespace mlir { +class AffineMap; +} // namespace mlir + +namespace linalg { + +/// Given a `map` specification and a subset of its results +/// `[beginResult, endResult)`, returns the inverse map that maps result +/// positions to dim positions. +mlir::AffineMap inverseSubMap(mlir::AffineMap map, unsigned beginResult = 0, + unsigned endResult = 0); + +} // namespace linalg + +#endif // LINALG3_ANALYSIS_H_ diff --git a/mlir/tutorial/Linalg3/include/linalg3/Intrinsics.h b/mlir/tutorial/Linalg3/include/linalg3/Intrinsics.h new file mode 100644 index 000000000000..75a04178162f --- /dev/null +++ b/mlir/tutorial/Linalg3/include/linalg3/Intrinsics.h @@ -0,0 +1,31 @@ +//===- Intrinsics.h - Linalg intrinsics definitions -----------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG3_INTRINSICS_H_ +#define LINALG3_INTRINSICS_H_ + +#include "linalg2/Intrinsics.h" +#include "linalg3/Ops.h" + +namespace linalg { +namespace intrinsics { +using load = mlir::edsc::intrinsics::ValueBuilder; +using store = mlir::edsc::intrinsics::OperationBuilder; +} // namespace intrinsics +} // namespace linalg + +#endif // LINALG3_INTRINSICS_H_ diff --git a/mlir/tutorial/Linalg3/include/linalg3/LoadStoreOps.h b/mlir/tutorial/Linalg3/include/linalg3/LoadStoreOps.h new file mode 100644 index 000000000000..b77e7028ea41 --- /dev/null +++ b/mlir/tutorial/Linalg3/include/linalg3/LoadStoreOps.h @@ -0,0 +1,89 @@ +//===- LoadStoreOps.h - Linalg dialect Load/Store operation definitions ---===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG3_LOADSTOREOP_H_ +#define LINALG3_LOADSTOREOP_H_ + +#include "mlir/IR/OpDefinition.h" +#include "mlir/Support/LLVM.h" + +namespace linalg { + +class ViewType; + +/// A linalg.LoadOp is the counterpart of affine.load but operating on ViewType +/// instead of MemRefType. +class LoadOp : public mlir::Op { +public: + using Op::Op; + + ////////////////////////////////////////////////////////////////////////////// + // Hooks to customize the behavior of this op. + ////////////////////////////////////////////////////////////////////////////// + static llvm::StringRef getOperationName() { return "linalg.load"; } + static void build(mlir::Builder *b, mlir::OperationState *result, + mlir::Value *view, + mlir::ArrayRef indices = {}); + mlir::LogicalResult verify(); + static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + void print(mlir::OpAsmPrinter *p); + + ////////////////////////////////////////////////////////////////////////////// + // Op-specific functionality. + ////////////////////////////////////////////////////////////////////////////// + unsigned getRank(); + ViewType getViewType(); + mlir::Value *getView() { return getOperand(0); } + mlir::Operation::operand_range getIndices() { + return {operand_begin() + 1, operand_end()}; + } +}; + +/// A linalg.StoreOp is the counterpart of affine.store but operating on +/// ViewType instead of MemRefType. +class StoreOp : public mlir::Op { +public: + using Op::Op; + + ////////////////////////////////////////////////////////////////////////////// + // Hooks to customize the behavior of this op. + ////////////////////////////////////////////////////////////////////////////// + static llvm::StringRef getOperationName() { return "linalg.store"; } + static void build(mlir::Builder *b, mlir::OperationState *result, + mlir::Value *valueToStore, mlir::Value *view, + mlir::ArrayRef indices = {}); + mlir::LogicalResult verify(); + static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + void print(mlir::OpAsmPrinter *p); + + ////////////////////////////////////////////////////////////////////////////// + // Op-specific functionality. + ////////////////////////////////////////////////////////////////////////////// + unsigned getRank(); + ViewType getViewType(); + mlir::Value *getValueToStore() { return getOperand(0); } + mlir::Value *getView() { return getOperand(1); } + mlir::Operation::operand_range getIndices() { + return {operand_begin() + 2, operand_end()}; + } +}; + +} // namespace linalg + +#endif // LINALG3_LOADSTOREOP_H_ diff --git a/mlir/tutorial/Linalg3/include/linalg3/Ops.h b/mlir/tutorial/Linalg3/include/linalg3/Ops.h index f2d5ec453d4c..813cbff74af8 100644 --- a/mlir/tutorial/Linalg3/include/linalg3/Ops.h +++ b/mlir/tutorial/Linalg3/include/linalg3/Ops.h @@ -19,6 +19,7 @@ #define LINALG3_OPS_H_ #include "linalg2/Ops.h" +#include "linalg3/LoadStoreOps.h" #include "linalg3/TensorOps.h" #endif // LINALG3_OPS_H_ diff --git a/mlir/tutorial/Linalg3/include/linalg3/TensorOps-inl.h b/mlir/tutorial/Linalg3/include/linalg3/TensorOps-inl.h index c4082d5b2919..60d99abed850 100644 --- a/mlir/tutorial/Linalg3/include/linalg3/TensorOps-inl.h +++ b/mlir/tutorial/Linalg3/include/linalg3/TensorOps-inl.h @@ -22,9 +22,10 @@ #define LINALG3_TENSOROPS_INL_H_ #include "linalg1/Common.h" +#include "linalg1/Utils.h" #include "linalg2/TensorOps.h" - -namespace linalg { +#include "linalg3/Analysis.h" +#include "linalg3/Ops.h" template mlir::Value * @@ -38,6 +39,90 @@ linalg::TensorContractionBase::getOutputView(unsigned i) { return *(getOutputs().begin() + i); } -} // namespace linalg +template +mlir::AffineMap +linalg::TensorContractionBase::loopsToOperandRangesMap() { + return static_cast(this)->loopsToOperandRangesMap(); +} -#endif // LINALG3_TENSOROPS-INL_H_ +template +void linalg::TensorContractionBase::emitScalarImplementation( + llvm::ArrayRef parallelIvs, + llvm::ArrayRef reductionIvs) { + static_cast(this)->emitScalarImplementation(parallelIvs, + reductionIvs); +} + +template +mlir::AffineMap linalg::operandRangesToLoopsMap( + linalg::TensorContractionBase &tensorContraction) { + return inverseSubMap(tensorContraction.loopsToOperandRangesMap()); +} + +// Extract the ranges from a given ViewOp or SliceOp. +// +// In the case of a ViewOp, things are simple: just traverse the indexings and +// get all the ranges (i.e. drop the indices). +// +// In the case of a SliceOp, things are trickier because we need to handle a +// potential rank-reduction: +// 1. Examine the indexing to determine if it is rank-reducing. +// 2. If it is rank-reducing, an offset of 1 is added to the dimensions such +// that `d >= slicingDim`. This is to account for the rank reduction. +// `getRootIndex` is then called on the **parent** view +static llvm::SmallVector +extractRangesFromViewOrSliceOp(mlir::Value *view) { + // This expects a viewType which must come from either ViewOp or SliceOp. + assert(view->getType().isa() && "expected ViewType"); + if (auto viewOp = view->getDefiningOp()->dyn_cast()) + return viewOp.getRanges(); + + auto sliceOp = view->getDefiningOp()->cast(); + unsigned slicingDim = sliceOp.getSlicingDim(); + auto *indexing = *(sliceOp.getIndexings().begin()); + bool isRankReducing = indexing->getType().isa(); + unsigned offset = 0; + llvm::SmallVector res; + res.reserve(sliceOp.getRank()); + for (unsigned d = 0, e = sliceOp.getRank(); d < e; ++d) { + if (d == slicingDim && isRankReducing) + offset = 1; + auto *parentView = sliceOp.getParentView(); + auto indexingPosPair = linalg::getViewRootIndexing(parentView, d + offset); + res.push_back(indexingPosPair.first); + } + return res; +} + +template +static llvm::SmallVector +getInputRanges(linalg::TensorContractionBase &tensorContraction) { + llvm::SmallVector res; + for (auto *in : tensorContraction.getInputs()) { + auto subres = extractRangesFromViewOrSliceOp(in); + res.append(subres.begin(), subres.end()); + } + return res; +} + +template +static llvm::SmallVector +getOutputRanges(linalg::TensorContractionBase &tensorContraction) { + llvm::SmallVector res; + for (auto *out : tensorContraction.getOutputs()) { + auto subres = extractRangesFromViewOrSliceOp(out); + res.append(subres.begin(), subres.end()); + } + return res; +} + +template +llvm::SmallVector linalg::getRanges( + linalg::TensorContractionBase &tensorContraction) { + llvm::SmallVector res = getInputRanges(tensorContraction); + llvm::SmallVector tmp = getOutputRanges(tensorContraction); + res.append(tmp.begin(), tmp.end()); + return res; +} + +#endif // LINALG3_TENSOROPS_INL_H_ diff --git a/mlir/tutorial/Linalg3/include/linalg3/TensorOps.h b/mlir/tutorial/Linalg3/include/linalg3/TensorOps.h index 3dffd6ef5f3a..4ade1925bae6 100644 --- a/mlir/tutorial/Linalg3/include/linalg3/TensorOps.h +++ b/mlir/tutorial/Linalg3/include/linalg3/TensorOps.h @@ -20,6 +20,32 @@ #include "linalg2/TensorOps.h" +namespace linalg { + +/// +/// Ideally all these functions would go in an Analysis but until +/// TensorContractionBase is templated, they need to remain close enough. +/// + +/// Takes a `tensorContraction` and a returns an AffineMap that can be used to +/// map ranges to enclosing loops for all the operands' ranges. +template +mlir::AffineMap operandRangesToLoopsMap( + linalg::TensorContractionBase &tensorContraction); + +/// Takes a `tensorContraction` and returns the ranges of all its operands. +/// When an operand comes from a ViewOp, things are simple: +/// just traverse the indexings and get all the ranges +/// (i.e. drop the rank-reducing indices). +/// In the case of a SliceOp, things are more involved because we need to handle +/// potential rank-reductions. +/// This function abstracts this complexity away and returns all the ranges. +template +llvm::SmallVector +getRanges(linalg::TensorContractionBase &tensorContraction); + +} // namespace linalg + /// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of /// TensorOps by adding implementations as they are needed in the appropriate /// step in the tutorial. diff --git a/mlir/tutorial/Linalg3/include/linalg3/Transforms.h b/mlir/tutorial/Linalg3/include/linalg3/Transforms.h index b5e11ddd0db2..5cc76926a1c2 100644 --- a/mlir/tutorial/Linalg3/include/linalg3/Transforms.h +++ b/mlir/tutorial/Linalg3/include/linalg3/Transforms.h @@ -30,10 +30,17 @@ namespace linalg { /// to only use linalg.view operations. void composeSliceOps(mlir::Function *f); -/// Traverses `f` and rewrites linalg.matmul (resp. linalg.matvec, linalg.dot) -/// as linalg.matvec (resp. linalg.dot, loop form). +/// Traverses `f` and rewrites linalg.load and linalg.store to affine.load and +/// affine.store operations. +void lowerLinalgLoadStores(mlir::Function *f); + +/// Traverses `f` and rewrites linalg.matmul (resp. linalg.matvec) +/// as linalg.matvec (resp. linalg.dot). void lowerToFinerGrainedTensorContraction(mlir::Function *f); +/// Traverses `f` and rewrites linalg operations in loop form. +void lowerToLoops(mlir::Function *f); + } // namespace linalg #endif // LINALG3_TRANSFORMS_H_ diff --git a/mlir/tutorial/Linalg3/lib/Analysis.cpp b/mlir/tutorial/Linalg3/lib/Analysis.cpp new file mode 100644 index 000000000000..9e7c8eee5a03 --- /dev/null +++ b/mlir/tutorial/Linalg3/lib/Analysis.cpp @@ -0,0 +1,62 @@ +//===- Analysis.cpp - Implementation of analysis functions for Linalg -----===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements a simple IR operation to create a new RangeType in the +// linalg dialect. +// +//===----------------------------------------------------------------------===// + +#include "linalg3/Analysis.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/StandardTypes.h" + +using llvm::SmallVector; +using namespace mlir; + +// Compute an inverse map (only works with permutations for now). +// Note that the mapping is generally non-full rank, so this returns the first +// seen entry for each dim. +static AffineMap inversePermutationMap(AffineMap map) { + SmallVector exprs(map.getNumDims()); + for (auto en : llvm::enumerate(map.getResults())) { + auto expr = en.value(); + auto d = expr.dyn_cast(); + assert(d && "permutation map expected"); + if (exprs[d.getPosition()]) + continue; + exprs[d.getPosition()] = getAffineDimExpr(en.index(), d.getContext()); + } + SmallVector seenExprs; + seenExprs.reserve(map.getNumDims()); + for (auto expr : exprs) + if (expr) + seenExprs.push_back(expr); + assert(map.getNumSymbols() == 0 && "expected map without symbols"); + assert(seenExprs.size() == map.getNumInputs() && "map is not invertible"); + return AffineMap::get(map.getNumResults(), 0, seenExprs, {}); +} + +mlir::AffineMap linalg::inverseSubMap(AffineMap map, unsigned beginResult, + unsigned endResult) { + if (beginResult == 0 && endResult == 0) + endResult = map.getNumResults(); + auto subMap = AffineMap::get( + map.getNumDims(), map.getNumSymbols(), + map.getResults().slice(beginResult, endResult - beginResult), {}); + return inversePermutationMap(subMap); +} diff --git a/mlir/tutorial/Linalg3/lib/DialectRegistration.cpp b/mlir/tutorial/Linalg3/lib/DialectRegistration.cpp new file mode 100644 index 000000000000..1ab27516846d --- /dev/null +++ b/mlir/tutorial/Linalg3/lib/DialectRegistration.cpp @@ -0,0 +1,39 @@ +//===- DialectRegistration.cpp - Registration of the Linalg dialect -------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file registers the Linalg dialect and should live in a standalone +// library. Linking with this library will create a static global object that +// performs dialect registration. +// +//===----------------------------------------------------------------------===// + +#include "linalg1/Dialect.h" +#include "linalg1/Types.h" +#include "linalg3/Ops.h" + +using namespace linalg; + +LinalgDialect::LinalgDialect(mlir::MLIRContext *context) + : Dialect("linalg", context) { + addTypes(); + addOperations(); +} + +// Dialect registration triggers the creation of a `LinalgDialect` object which +// adds the proper types and operations to the dialect. +static mlir::DialectRegistration LinalgOps; diff --git a/mlir/tutorial/Linalg3/lib/LoadStoreOps.cpp b/mlir/tutorial/Linalg3/lib/LoadStoreOps.cpp new file mode 100644 index 000000000000..340916f013b7 --- /dev/null +++ b/mlir/tutorial/Linalg3/lib/LoadStoreOps.cpp @@ -0,0 +1,136 @@ +//===- LoadStoreOps.cpp - Implementation of linalg Load/Store operations --===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements linalg.load and linalg.store operations which allow +// accessing memory through ViewType values. +// +//===----------------------------------------------------------------------===// + +#include "linalg3/LoadStoreOps.h" +#include "linalg3/Ops.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/StandardTypes.h" + +using llvm::ArrayRef; +using namespace mlir; +using namespace linalg; + +//////////////////////////////////////////////////////////////////////////////// +// LoadOp. +//////////////////////////////////////////////////////////////////////////////// +void linalg::LoadOp::build(Builder *b, OperationState *result, Value *view, + ArrayRef indices) { + auto viewType = view->getType().cast(); + result->addOperands(view); + result->addOperands(indices); + result->addTypes(viewType.getElementType()); +} + +void linalg::LoadOp::print(OpAsmPrinter *p) { + *p << getOperationName() << " " << *getView() << '['; + p->printOperands(getIndices()); + *p << ']'; + p->printOptionalAttrDict(getAttrs()); + *p << " : " << getViewType(); +} + +bool linalg::LoadOp::parse(OpAsmParser *parser, OperationState *result) { + llvm_unreachable("Parsing linalg dialect is not supported in this tutorial"); + return false; +} + +LogicalResult linalg::LoadOp::verify() { + if (getNumOperands() == 0) + return emitOpError("expected a view to load from"); + + auto viewType = getView()->getType().dyn_cast(); + if (!viewType) + return emitOpError("first operand must be a view"); + + if (getType() != viewType.getElementType()) + return emitOpError("result type must match element type of the view"); + + if (getRank() != getNumOperands() - 1) + return emitOpError("incorrect number of indices for load"); + + for (auto *idx : getIndices()) + if (!idx->getType().isIndex()) + return emitOpError("index to load must have 'index' type"); + + return success(); +} + +ViewType linalg::LoadOp::getViewType() { + return getView()->getType().cast(); +} + +unsigned linalg::LoadOp::getRank() { return getViewType().getRank(); } + +//////////////////////////////////////////////////////////////////////////////// +// StoreOp. +//////////////////////////////////////////////////////////////////////////////// +void linalg::StoreOp::build(Builder *b, OperationState *result, + Value *valueToStore, Value *view, + ArrayRef indices) { + result->addOperands(valueToStore); + result->addOperands(view); + result->addOperands(indices); +} + +void linalg::StoreOp::print(OpAsmPrinter *p) { + *p << getOperationName() << " " << *getValueToStore(); + *p << ", " << *getView() << '['; + p->printOperands(getIndices()); + *p << ']'; + p->printOptionalAttrDict(getAttrs()); + *p << " : " << getViewType(); +} + +bool linalg::StoreOp::parse(OpAsmParser *parser, OperationState *result) { + assert(false && "NYI"); + return false; +} + +LogicalResult linalg::StoreOp::verify() { + if (getNumOperands() < 2) + return emitOpError("expected a value to store and a view"); + + // Second operand is a memref type. + auto viewType = getView()->getType().dyn_cast(); + if (!viewType) + return emitOpError("second operand must be a view"); + + // First operand must have same type as memref element type. + if (getValueToStore()->getType() != viewType.getElementType()) + return emitOpError("first operand must have same element type as the view"); + + if (getNumOperands() != 2 + viewType.getRank()) + return emitOpError("store index operand count not equal to view rank"); + + for (auto *idx : getIndices()) + if (!idx->getType().isIndex()) + return emitOpError("index to store must have 'index' type"); + + return success(); +} + +unsigned linalg::StoreOp::getRank() { return getViewType().getRank(); } + +ViewType linalg::StoreOp::getViewType() { + return getView()->getType().cast(); +} diff --git a/mlir/tutorial/Linalg3/lib/TensorOps.cpp b/mlir/tutorial/Linalg3/lib/TensorOps.cpp index a04d772a7cda..61eaa06c0dff 100644 --- a/mlir/tutorial/Linalg3/lib/TensorOps.cpp +++ b/mlir/tutorial/Linalg3/lib/TensorOps.cpp @@ -22,7 +22,7 @@ #include "linalg1/Analysis.h" #include "linalg1/Common.h" -#include "linalg2/Intrinsics.h" +#include "linalg3/Intrinsics.h" #include "linalg3/Ops.h" #include "mlir/IR/Builders.h" #include "mlir/IR/OpDefinition.h" @@ -32,23 +32,121 @@ using namespace mlir; using namespace mlir::edsc; +using namespace mlir::edsc::intrinsics; using namespace linalg; using namespace linalg::intrinsics; +////////////////////////////////////////////////////////////////////////////// +// Implementation of DotOp. +////////////////////////////////////////////////////////////////////////////// +AffineMap linalg::DotOp::loopsToOperandRangesMap() { + // A(K), B(K), C() + assert(getRanges(*this).size() == 2); + auto *context = ScopedContext::getContext(); + auto d0 = getAffineDimExpr(0, context); // K + // A(K), B(K), C() + // (d0) -> (d0, d0)(%k) + return AffineMap::get(1, 0, {d0, d0}, {}); +} + +void linalg::DotOp::emitScalarImplementation( + llvm::ArrayRef parallelIvs, llvm::ArrayRef reductionIvs) { + using IndexedValue = TemplatedIndexedValue; + assert(reductionIvs.size() == 1); + auto innermostLoop = getForInductionVarOwner(reductionIvs.back()); + auto *body = innermostLoop.getBody(); + using edsc::op::operator+; + using edsc::op::operator*; + using edsc::op::operator==; + using edsc::intrinsics::select; + ScopedContext scope( // account for affine.terminator in loop. + FuncBuilder(body, std::prev(body->end(), 1)), innermostLoop.getLoc()); + auto f32 = ScopedContext::getBuilder()->getF32Type(); + IndexHandle zero(constant_index(0)); + ValueHandle zerof = constant_float(llvm::APFloat(0.0f), f32); + IndexHandle r_i(reductionIvs[0]); + IndexedValue A(getOperand(0)), B(getOperand(1)), C(getOperand(2)); + ValueHandle cond = (r_i == zero); + ValueHandle scalarC = select(cond, zerof, *C()); + C() = scalarC + A(r_i) * B(r_i); +} + +////////////////////////////////////////////////////////////////////////////// +// Implementation of MatvecOp. +////////////////////////////////////////////////////////////////////////////// +AffineMap linalg::MatvecOp::loopsToOperandRangesMap() { + // A(M, K), B(K), C(M) + assert(getRanges(*this).size() == 4); + auto *context = ScopedContext::getContext(); + auto d0 = getAffineDimExpr(0, context); // M + auto d1 = getAffineDimExpr(1, context); // K + // A(M, K), B(K), C(M) + // (d0, d1) -> (d0, d1, d1, d0)(%m, %k) + return AffineMap::get(2, 0, {d0, d1, d1, d0}, {}); +} + // The body expression for matvec is: C(i) = scalarC + A(i, r_j) * B(r_j) // The body expression for dot is: C() = A(r_i) * B(r_i); // So we must drop the `i` loop from the matvec. void linalg::MatvecOp::writeAsFinerGrainTensorContraction() { auto *op = getOperation(); - ScopedContext scope(FuncBuilder(op), op->getLoc()); - IndexHandle i; auto *vA(getInputView(0)), *vB(getInputView(1)), *vC(getOutputView(0)); auto indexingPosPair = getViewRootIndexing(vA, 0); assert(indexingPosPair.first->getDefiningOp() && indexingPosPair.first->getDefiningOp()->isa()); - linalg::common::LoopNestRangeBuilder(&i, ValueHandle(indexingPosPair.first))({ - dot(slice(vA, i, 0), vB, slice(vC, i, 0)), + // clang-format off + ScopedContext scope(FuncBuilder(op), op->getLoc()); + IndexHandle i; + using linalg::common::LoopNestRangeBuilder; + LoopNestRangeBuilder(&i, ValueHandle(indexingPosPair.first))({ + [&i, &vA, &vB, &vC]() { + ValueHandle sliceA = slice(vA, i, 0); + ValueHandle sliceC = slice(vC, i, 0); + dot(sliceA, vB, sliceC); + /// NestedBuilders expect handles, we thus return an IndexHandle. + return IndexHandle(); + }() }); + // clang-format on +} + +void linalg::MatvecOp::emitScalarImplementation( + llvm::ArrayRef parallelIvs, llvm::ArrayRef reductionIvs) { + using IndexedValue = TemplatedIndexedValue; + assert(reductionIvs.size() == 1); + auto innermostLoop = getForInductionVarOwner(reductionIvs.back()); + auto *body = innermostLoop.getBody(); + using edsc::op::operator+; + using edsc::op::operator*; + using edsc::op::operator==; + using edsc::intrinsics::select; + ScopedContext scope( // account for affine.terminator in loop. + FuncBuilder(body, std::prev(body->end(), 1)), innermostLoop.getLoc()); + auto f32 = ScopedContext::getBuilder()->getF32Type(); + IndexHandle i(parallelIvs[0]), r_j(reductionIvs[0]); + IndexedValue A(getOperand(0)), B(getOperand(1)), C(getOperand(2)); + IndexHandle zero(constant_index(0)); + ValueHandle zerof = constant_float(llvm::APFloat(0.0f), f32); + ValueHandle cond = (r_j == zero); + ValueHandle scalarC = select(cond, zerof, *C(i)); + C(i) = scalarC + A(i, r_j) * B(r_j); +} + +////////////////////////////////////////////////////////////////////////////// +// Op-specific Matmul. +////////////////////////////////////////////////////////////////////////////// +AffineMap linalg::MatmulOp::loopsToOperandRangesMap() { + // A(M, K), B(K, N), C(M, N) + assert(getRanges(*this).size() == 6); + auto *context = ScopedContext::getContext(); + auto d0 = getAffineDimExpr(0, context); // M + auto d1 = getAffineDimExpr(1, context); // N + auto d2 = getAffineDimExpr(2, context); // K + // A(M, K), B(K, N), C(M, N): + // (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1)(%m, %n, %k) + return AffineMap::get(3, 0, {d0, d2, d2, d1, d0, d1}, {}); } // The body expression for matmul is: C(i, j) = scalarC + A(i, r_k) * B(r_k, j) @@ -58,13 +156,45 @@ void linalg::MatvecOp::writeAsFinerGrainTensorContraction() { // declaratively. void linalg::MatmulOp::writeAsFinerGrainTensorContraction() { auto *op = getOperation(); - ScopedContext scope(FuncBuilder(op), op->getLoc()); - IndexHandle j; auto *vA(getInputView(0)), *vB(getInputView(1)), *vC(getOutputView(0)); auto indexingPosPair = getViewRootIndexing(vB, 1); assert(indexingPosPair.first->getDefiningOp() && indexingPosPair.first->getDefiningOp()->isa()); - linalg::common::LoopNestRangeBuilder(&j, ValueHandle(indexingPosPair.first))({ - matvec(vA, slice(vB, j, 1), slice(vC, j, 1)), + using linalg::common::LoopNestRangeBuilder; + // clang-format off + ScopedContext scope(FuncBuilder(op), op->getLoc()); + IndexHandle j; + LoopNestRangeBuilder(&j, ValueHandle(indexingPosPair.first))({ + [&j, &vA, &vB, &vC]() { + ValueHandle sliceB = slice(vB, j, 1); + ValueHandle sliceC = slice(vC, j, 1); + matvec(vA, sliceB, sliceC); + /// NestedBuilders expect handles, we thus return an IndexHandle. + return IndexHandle(); + }() }); + // clang-format on +} + +void linalg::MatmulOp::emitScalarImplementation( + llvm::ArrayRef parallelIvs, llvm::ArrayRef reductionIvs) { + using IndexedValue = TemplatedIndexedValue; + assert(reductionIvs.size() == 1); + auto innermostLoop = getForInductionVarOwner(reductionIvs.back()); + auto *body = innermostLoop.getBody(); + using edsc::op::operator+; + using edsc::op::operator*; + using edsc::op::operator==; + using edsc::intrinsics::select; + ScopedContext scope( // account for affine.terminator in loop. + FuncBuilder(body, std::prev(body->end(), 1)), innermostLoop.getLoc()); + auto f32 = ScopedContext::getBuilder()->getF32Type(); + IndexHandle i(parallelIvs[0]), j(parallelIvs[1]), r_k(reductionIvs[0]); + IndexedValue A(getOperand(0)), B(getOperand(1)), C(getOperand(2)); + IndexHandle zero(constant_index(0)); + ValueHandle zerof = constant_float(llvm::APFloat(0.0f), f32); + ValueHandle cond = r_k == zero; + ValueHandle scalarC = select(cond, zerof, *C(i, j)); + C(i, j) = scalarC + A(i, r_k) * B(r_k, j); } diff --git a/mlir/tutorial/Linalg3/lib/Transforms.cpp b/mlir/tutorial/Linalg3/lib/Transforms.cpp index aa9fbd062159..070ef5e6d0c3 100644 --- a/mlir/tutorial/Linalg3/lib/Transforms.cpp +++ b/mlir/tutorial/Linalg3/lib/Transforms.cpp @@ -53,3 +53,171 @@ void linalg::lowerToFinerGrainedTensorContraction(mlir::Function *f) { op->erase(); }); } + +// Folding eagerly is necessary to abide by affine.for static step requirement. +// Returns nullptr if folding is not trivially feasible. +static Value *tryFold(AffineMap map, SmallVector operands) { + assert(map.getNumResults() == 1 && "single result map expected"); + auto expr = map.getResult(0); + if (auto dim = expr.dyn_cast()) + return operands[dim.getPosition()]; + if (auto sym = expr.dyn_cast()) + return operands[map.getNumDims() + sym.getPosition()]; + if (auto cst = expr.dyn_cast()) + return constant_index(cst.getValue()); + return nullptr; +} + +static Value *makeFoldedComposedAffineApply(AffineMap map, + ArrayRef operandsRef) { + SmallVector operands(operandsRef.begin(), operandsRef.end()); + fullyComposeAffineMapAndOperands(&map, &operands); + if (auto *v = tryFold(map, operands)) { + return v; + } + auto *b = ScopedContext::getBuilder(); + auto loc = ScopedContext::getLocation(); + return b->create(loc, map, operands).getResult(); +} + +struct RangeParts { + explicit RangeParts(unsigned reserved); + RangeParts(ArrayRef ranges); + + SmallVector makeRanges(); + + SmallVector mins; + SmallVector maxes; + SmallVector steps; +}; + +RangeParts::RangeParts(unsigned reserved) { + mins.reserve(reserved); + maxes.reserve(reserved); + steps.reserve(reserved); +} + +static SmallVector +extractFromRanges(ArrayRef ranges, + std::function extract) { + SmallVector res; + res.reserve(ranges.size()); + for (auto *v : ranges) { + auto r = v->getDefiningOp()->cast(); + res.push_back(extract(r)); + } + return res; +} + +RangeParts::RangeParts(ArrayRef ranges) + : mins(extractFromRanges(ranges, [](RangeOp r) { return r.getMin(); })), + maxes(extractFromRanges(ranges, [](RangeOp r) { return r.getMax(); })), + steps(extractFromRanges(ranges, [](RangeOp r) { return r.getStep(); })) {} + +SmallVector RangeParts::makeRanges() { + SmallVector res; + res.reserve(mins.size()); + for (auto z : llvm::zip(mins, maxes, steps)) { + res.push_back(range(std::get<0>(z), std::get<1>(z), std::get<2>(z))); + } + return res; +} + +static RangeParts makeGenericRangeParts(AffineMap map, + ArrayRef ranges) { + assert(map.getNumInputs() == ranges.size()); + unsigned numDims = map.getNumDims(); + assert(map.getNumSymbols() == 0); + assert(map.getRangeSizes().empty()); + + RangeParts res(map.getNumResults()); + RangeParts rangeParts(ranges); + for (auto expr : map.getResults()) { + AffineMap map = AffineMap::get(numDims, 0, expr, {}); + res.mins.push_back(makeFoldedComposedAffineApply(map, rangeParts.mins)); + res.maxes.push_back(makeFoldedComposedAffineApply(map, rangeParts.maxes)); + res.steps.push_back(makeFoldedComposedAffineApply(map, rangeParts.steps)); + } + return res; +} + +SmallVector makeGenericRanges(AffineMap map, + ArrayRef ranges) { + return makeGenericRangeParts(map, ranges).makeRanges(); +} + +static SmallVector makeGenericLoopRanges( + AffineMap operandRangesToLoopsMap, ArrayRef ranges, + llvm::Optional> tileSizes = llvm::None) { + RangeParts res = makeGenericRangeParts(operandRangesToLoopsMap, ranges); + if (!tileSizes.hasValue()) + return res.makeRanges(); + SmallVector tiledSteps; + for (auto z : llvm::zip(res.steps, *tileSizes)) { + auto *step = std::get<0>(z); + auto tileSize = std::get<1>(z); + auto stepValue = step->getDefiningOp()->cast().getValue(); + auto tileSizeValue = + tileSize->getDefiningOp()->cast().getValue(); + assert(stepValue > 0); + tiledSteps.push_back(constant_index(stepValue * tileSizeValue)); + } + res.steps = tiledSteps; + return res.makeRanges(); +} + +template +static SmallVector +writeAsLoops(ContractionOp contraction) { + ScopedContext scope(mlir::FuncBuilder(contraction.getOperation()), + contraction.getLoc()); + auto loopRanges = makeGenericLoopRanges(operandRangesToLoopsMap(contraction), + getRanges(contraction)); + + SmallVector parallelIvs(contraction.getNumParallelDims()); + SmallVector reductionIvs(contraction.getNumReductionDims()); + auto pivs = IndexHandle::makeIndexHandlePointers(parallelIvs); + auto rivs = IndexHandle::makeIndexHandlePointers(reductionIvs); + assert(loopRanges.size() == pivs.size() + rivs.size()); + + // clang-format off + using linalg::common::LoopNestRangeBuilder; + ArrayRef ranges(loopRanges); + LoopNestRangeBuilder(pivs, ranges.take_front(pivs.size()))({ + LoopNestRangeBuilder(rivs, ranges.take_back(rivs.size()))({ + [&contraction, ¶llelIvs, &reductionIvs]() { + SmallVector parallel( + parallelIvs.begin(), parallelIvs.end()); + SmallVector reduction( + reductionIvs.begin(), reductionIvs.end()); + contraction.emitScalarImplementation(parallel, reduction); + /// NestedBuilders expect handles, we thus return an IndexHandle. + return IndexHandle(); + }() + }) + }); + // clang-format on + + SmallVector res; + res.reserve(pivs.size() + rivs.size()); + for (auto iv : parallelIvs) + res.push_back(getForInductionVarOwner(iv.getValue())); + for (auto iv : reductionIvs) + res.push_back(getForInductionVarOwner(iv.getValue())); + return res; +} + +void linalg::lowerToLoops(mlir::Function *f) { + f->walkPostOrder([](Operation *op) { + if (auto matmulOp = op->dyn_cast()) { + writeAsLoops(matmulOp); + } else if (auto matvecOp = op->dyn_cast()) { + writeAsLoops(matvecOp); + } else if (auto dotOp = op->dyn_cast()) { + writeAsLoops(dotOp); + } else { + return; + } + op->erase(); + }); +}