Linalg portion of the tutorial - part 3-2

This CL adds support for lowering tensor contractions to loops declaratively.
    This is done thanks to two properties of such operations:
    1. the definition of an AffineMap getLoopsToOperandRangesMap for each op, which maps iteration space dimensions to the ranges of the view operands, in their order of occurrence;
    2. the definition of a scalar implementation for each op, which creates the computation inside the loops, given the enclosing parallel and reduction loops.

    All the other properties are derived in a generic fashion from these two properties and a few analyses, as illustrated below for matmul.

    A lowerToLoops transformation is added as well as a test that exercises it.
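    For illustration, here is what the two properties look like for linalg.matmul, i.e. C(M, N) = A(M, K) * B(K, N). Both snippets are excerpted from the MatmulOp changes below; everything else (range extraction, loop construction, folding) is derived generically:

        // Property 1: loops (%m, %n, %k) map to the operand ranges of
        // A(%m, %k), B(%k, %n), C(%m, %n).
        AffineMap linalg::MatmulOp::loopsToOperandRangesMap() {
          auto *context = ScopedContext::getContext();
          auto d0 = getAffineDimExpr(0, context); // M
          auto d1 = getAffineDimExpr(1, context); // N
          auto d2 = getAffineDimExpr(2, context); // K
          // (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1)(%m, %n, %k)
          return AffineMap::get(3, 0, {d0, d2, d2, d1, d0, d1}, {});
        }

        // Property 2: scalar body emitted in the innermost (reduction) loop,
        // in compact index notation:
        //   cond    = (r_k == zero)
        //   scalarC = select(cond, zerof, C[i, j]);
        //   C[i, j] = scalarC + A[i, r_k] * B[r_k, j];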

--

PiperOrigin-RevId: 241783992
Nicolas Vasilache 2019-04-03 12:33:01 -07:00 committed by Mehdi Amini
parent b9e3b2107b
commit 50df91745d
14 changed files with 978 additions and 29 deletions

View File

@@ -41,6 +41,7 @@ protected:
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
void print(mlir::OpAsmPrinter *p);
public:
//////////////////////////////////////////////////////////////////////////////
// Op-specific functionality.
//////////////////////////////////////////////////////////////////////////////
@@ -48,7 +49,6 @@ protected:
mlir::Operation::operand_range getInputs();
mlir::Operation::operand_range getOutputs();
public:
/// These are better as methods calling into the ConcreteOp instead of
/// template parameters because methods allow more generic behavior and avoid
/// specializing for number of arguments. All derived classes have
@@ -72,14 +72,24 @@ public:
//////////////////////////////////////////////////////////////////////////////
mlir::Value *getInputView(unsigned i);
mlir::Value *getOutputView(unsigned i);
/// Computes a mapping from all the ranges of the operands to the enclosing
/// loops. In order to support "broadcast"-style semantics, we need to
/// consider all the operands (i.e. input operands are not sufficient).
/// Each op is responsible for declaring how it lowers itself to scalar form,
/// given the enclosing parallel and reduction induction variables.
/// `emitScalarImplementation` emits the scalar IR for the op in the nesting
/// context of the innermost enclosing loop (i.e. `reductionIvs.back()` or
/// `parallelIvs.back()`).
void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
llvm::ArrayRef<mlir::Value *> reductionIvs);
/// Represents a mapping from the loops to all the ranges of the operands.
/// The operands and their ranges are in the order defined by the particular
/// ConcreteOp implementation; the resulting map must match that order.
/// This is currently computed but can also be specified explicitly in each
/// operator to generalize to cases where an analysis is not available.
mlir::AffineMap operandRangesToLoopsMap();
/// In favorable cases, this can be calculated by an analysis but specifying
/// it explicitly is not expensive and generalizes to cases where an analysis
/// is not available.
/// For details, see the description of loopsToOperandRangesMap in each
/// ConcreteOp.
mlir::AffineMap loopsToOperandRangesMap();
};
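To make the extension point explicit, a hypothetical new contraction op (not part of this commit, modeled on DotOp/MatvecOp/MatmulOp below) would only need to declare these two hooks for the generic lowering to apply:

    class SomeContractionOp
        : public TensorContractionBase<SomeContractionOp>,
          public mlir::Op<SomeContractionOp, mlir::OpTrait::VariadicOperands,
                          mlir::OpTrait::ZeroResult> {
    public:
      // Property 1: map from loop dimensions to all operand ranges
      // (hypothetical op; the map itself is op-specific).
      mlir::AffineMap loopsToOperandRangesMap();
      // Property 2: scalar body, given the enclosing parallel/reduction ivs.
      void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
                                    llvm::ArrayRef<mlir::Value *> reductionIvs);
    };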
/// Implements c = A * B where c is a scalar and A and B are 1-D vectors.
@@ -119,6 +129,27 @@ public:
/// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
/// loop over matvec). Does nothing by default.
void writeAsFinerGrainTensorContraction();
/// Inputs to this map will be (%k) coming from enclosing loops.
/// Therefore, the mapping to get back to A(K), B(K), C() is:
/// (d0) -> (d0, d0)(%k)
/// And the operands ranges are:
/// (%k, %k)
mlir::AffineMap loopsToOperandRangesMap();
/// Given an enclosing reduction loop with iv `r_i`, emits MLIR corresponding
/// to:
/// 1. conditionally assign scalarC to 0.0f on the first iteration or load
/// C[] from memory (0-D tensor)
/// 2. multiply A[r_i] by B[r_i] and add to scalarC
/// 3. store back scalarC at C[]
///
/// In some compact index notation this could be written:
/// cond = (r_i == zero)
/// scalarC = select(cond, zerof, C[]);
/// C[] = scalarC + A[r_i] * B[r_i];
void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
llvm::ArrayRef<mlir::Value *> reductionIvs);
};
/// Implements C = A * B where A is a 2-D matrix and B and C are 1-D vectors.
@@ -158,6 +189,27 @@ public:
/// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
/// loop over matvec). Does nothing by default.
void writeAsFinerGrainTensorContraction();
/// Inputs to this map will be (%m, %k) coming from enclosing loops.
/// Therefore, the mapping to get back to A(M, K), B(K), C(M) is:
/// (d0, d1) -> (d0, d1, d1, d0)(%m, %k)
/// And the operands ranges are:
/// (%m, %k, %k, %m)
mlir::AffineMap loopsToOperandRangesMap();
/// Given an enclosing parallel loop with iv `i` and an enclosing reduction
/// loop with iv `r_j`, emits MLIR corresponding to:
/// 1. conditionally assign scalarC to 0.0f on the first iteration or load
/// C[i]
/// 2. multiply A[i, r_j] by B[r_j] and add to scalarC
/// 3. store back scalarC at C[i]
///
/// In some compact index notation this could be written:
/// cond = (r_j == zero)
/// scalarC = select(cond, zerof, C(i));
/// C(i) = scalarC + A(i, r_j) * B(r_j);
void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
llvm::ArrayRef<mlir::Value *> reductionIvs);
};
/// Implements C = A * B on 2-D matrices.
@@ -197,6 +249,27 @@ public:
/// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
/// loop over matvec). Does nothing by default.
void writeAsFinerGrainTensorContraction();
/// Inputs to this map will be (%m, %n, %k) coming from enclosing loops.
/// Therefore, the mapping to get back to A(M, K), B(K, N), C(M, N) is:
/// (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1)(%m, %n, %k)
/// And the operands ranges are:
/// (%m, %k, %k, %n, %m, %n)
mlir::AffineMap loopsToOperandRangesMap();
/// Given enclosing parallel loops with ivs `i` and `j`, and an enclosing
/// reduction loop with iv `r_k`, emits MLIR corresponding to:
/// 1. conditionally assign scalarC to 0.0f on the first iteration or load
/// C[i, j]
/// 2. multiply A[i, r_k] by B[r_k, j] and add to scalarC
/// 3. store back scalarC at C[i, j]
///
/// In some compact index notation this could be written:
/// cond = (r_k == zero)
/// scalarC = select(cond, zerof, C[i, j]);
/// C[i, j] = scalarC + A[i, r_k] * B[r_k, j];
void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
llvm::ArrayRef<mlir::Value *> reductionIvs);
};
} // namespace linalg

View File

@@ -64,13 +64,15 @@ TEST_FUNC(matmul_as_matvec) {
Module module(&context);
mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec");
lowerToFinerGrainedTensorContraction(f);
composeSliceOps(f);
// clang-format off
// CHECK-LABEL: func @matmul_as_matvec(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
// CHECK: %[[N:.*]] = dim %arg2, 1 : memref<?x?xf32>
// CHECK: %[[vA:.*]] = linalg.view %arg0[%{{.*}}, %{{.*}}] : !linalg<"view<f32xf32>">
// CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[N]]) {
// CHECK-NEXT: %[[vB:.*]] = linalg.slice %{{.*}}[*, %i0] { dim : 1 } : !linalg<"view<f32>">
// CHECK-NEXT: %[[vC:.*]] = linalg.slice %{{.*}}[*, %i0] { dim : 1 } : !linalg<"view<f32>">
// CHECK-NEXT: linalg.matvec {%{{.*}}, %[[vB]]} -> {%[[vC]]}
// CHECK: %[[vB:.*]] = linalg.view %arg1[%{{.*}}, %{{.*}}] : !linalg<"view<f32>">
// CHECK: %[[vC:.*]] = linalg.view %arg2[%{{.*}}, %{{.*}}] : !linalg<"view<f32>">
// CHECK: linalg.matvec {%[[vA]], %[[vB]]} -> {%[[vC]]}
// clang-format on
cleanupAndPrintFunction(f);
}
@@ -81,21 +83,84 @@ TEST_FUNC(matmul_as_dot) {
mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_dot");
lowerToFinerGrainedTensorContraction(f);
lowerToFinerGrainedTensorContraction(f);
composeSliceOps(f);
// clang-format off
// CHECK-LABEL: func @matmul_as_dot(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
// CHECK: %[[M:.*]] = dim %arg0, 0 : memref<?x?xf32>
// CHECK: %[[N:.*]] = dim %arg2, 1 : memref<?x?xf32>
// CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[N]]) {
// CHECK-NEXT: %[[vB:.*]] = linalg.slice {{.*}}[*, %i0] { dim : 1 } : !linalg<"view<f32>">
// CHECK-NEXT: %[[sC:.*]] = linalg.slice {{.*}}[*, %i0] { dim : 1 } : !linalg<"view<f32>">
// CHECK: %[[vB:.*]] = linalg.view %arg1[%{{.*}}, %{{.*}}] : !linalg<"view<f32>">
// CHECK-NEXT: affine.for %i1 = 0 to (d0) -> (d0)(%[[M]]) {
// CHECK-NEXT: %[[vA:.*]] = linalg.slice {{.*}}[%i1, *] { dim : 0 } : !linalg<"view<f32>">
// CHECK-NEXT: %[[vC:.*]] = linalg.slice %[[sC]][%i1] { dim : 0 } : !linalg<"view<0xf32>">
// CHECK: %[[vA:.*]] = linalg.view %arg0[%{{.*}}, %{{.*}}] : !linalg<"view<f32>">
// CHECK-NEXT: %[[vC:.*]] = linalg.view %arg2[%{{.*}}, %{{.*}}] : !linalg<"view<0xf32>">
// CHECK-NEXT: linalg.dot {%[[vA]], %[[vB]]} -> {%[[vC]]}
// clang-format on
cleanupAndPrintFunction(f);
}
TEST_FUNC(matmul_as_loops) {
MLIRContext context;
Module module(&context);
mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_loops");
lowerToLoops(f);
composeSliceOps(f);
// clang-format off
// CHECK-LABEL: func @matmul_as_loops(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
// CHECK: %[[M:.*]] = dim %arg0, 0 : memref<?x?xf32>
// CHECK: %[[N:.*]] = dim %arg2, 1 : memref<?x?xf32>
// CHECK: %[[K:.*]] = dim %arg0, 1 : memref<?x?xf32>
// CHECK: %[[rM:.*]] = linalg.range %c0:%[[M]]:%c1 : !linalg<"range">
// CHECK: %[[rN:.*]] = linalg.range %c0:%[[N]]:%c1 : !linalg<"range">
// CHECK: %[[rK:.*]] = linalg.range %c0:%[[K]]:%c1 : !linalg<"range">
// CHECK: %[[vA:.*]] = linalg.view %arg0[%[[rM]], %[[rK]]] : !linalg<"view<f32xf32>">
// CHECK: %[[vB:.*]] = linalg.view %arg1[%[[rK]], %[[rN]]] : !linalg<"view<f32xf32>">
// CHECK: %[[vC:.*]] = linalg.view %arg2[%[[rM]], %[[rN]]] : !linalg<"view<f32xf32>">
// CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[M]]) {
// CHECK: affine.for %i1 = 0 to (d0) -> (d0)(%[[N]]) {
// CHECK: affine.for %i2 = 0 to (d0) -> (d0)(%[[K]]) {
// CHECK: %{{.*}} = cmpi "eq", %{{.*}} : index
// CHECK: %{{.*}} = linalg.load %[[vC]][%i0, %i1] : !linalg<"view<f32xf32>">
// CHECK: %{{.*}} = select {{.*}} : f32
// CHECK: %{{.*}} = linalg.load %[[vB]][%i2, %i1] : !linalg<"view<f32xf32>">
// CHECK: %{{.*}} = linalg.load %[[vA]][%i0, %i2] : !linalg<"view<f32xf32>">
// CHECK: %{{.*}} = mulf {{.*}} : f32
// CHECK: %{{.*}} = addf {{.*}} : f32
// CHECK: linalg.store {{.*}}[%i0, %i1] : !linalg<"view<f32xf32>">
// clang-format on
cleanupAndPrintFunction(f);
}
TEST_FUNC(matmul_as_matvec_as_loops) {
MLIRContext context;
Module module(&context);
mlir::Function *f =
makeFunctionWithAMatmulOp(module, "matmul_as_matvec_as_loops");
lowerToFinerGrainedTensorContraction(f);
lowerToLoops(f);
composeSliceOps(f);
// clang-format off
// CHECK-LABEL: func @matmul_as_matvec_as_loops(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
// CHECK: %[[M:.*]] = dim %arg0, 0 : memref<?x?xf32>
// CHECK: %[[N:.*]] = dim %arg2, 1 : memref<?x?xf32>
// CHECK: %[[K:.*]] = dim %arg0, 1 : memref<?x?xf32>
// CHECK: %[[vA:.*]] = linalg.view %arg0[{{.*}}, {{.*}}] : !linalg<"view<f32xf32>">
// CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[N]]) {
// CHECK: %[[vB:.*]] = linalg.view %arg1[{{.*}}, {{.*}}] : !linalg<"view<f32>">
// CHECK: %[[vC:.*]] = linalg.view %arg2[{{.*}}, {{.*}}] : !linalg<"view<f32>">
// CHECK: affine.for %i1 = 0 to (d0) -> (d0)(%[[M]]) {
// CHECK: affine.for %i2 = 0 to (d0) -> (d0)(%[[K]]) {
// CHECK: %{{.*}} = cmpi "eq", %i2, %{{.*}} : index
// CHECK: %[[C:.*]] = linalg.load %[[vC]][%i1] : !linalg<"view<f32>">
// CHECK: %[[C2:.*]] = select %{{.*}}, %{{.*}}, %[[C]] : f32
// CHECK: %[[B:.*]] = linalg.load %[[vB]][%i2] : !linalg<"view<f32>">
// CHECK: %[[A:.*]] = linalg.load %[[vA]][%i1, %i2] : !linalg<"view<f32xf32>">
// CHECK: %{{.*}} = mulf %[[A]], %[[B]] : f32
// CHECK: %{{.*}} = addf %[[C2]], %{{.*}} : f32
// CHECK: linalg.store %{{.*}}, %{{.*}}[%i1] : !linalg<"view<f32>">
// clang-format on
cleanupAndPrintFunction(f);
}
int main() {
RUN_TESTS();
return 0;

View File

@@ -0,0 +1,37 @@
//===- Analysis.h - Linalg dialect Analysis function definitions ----------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef LINALG3_ANALYSIS_H_
#define LINALG3_ANALYSIS_H_
#include "linalg2/Analysis.h"
namespace mlir {
class AffineMap;
} // namespace mlir
namespace linalg {
/// Given a `map` specification and a subset of its results
/// `[beginResult, endResult)`, returns the inverse map that maps result
/// positions to dim positions.
mlir::AffineMap inverseSubMap(mlir::AffineMap map, unsigned beginResult = 0,
unsigned endResult = 0);
} // namespace linalg
#endif // LINALG3_ANALYSIS_H_

View File

@@ -0,0 +1,31 @@
//===- Intrinsics.h - Linalg intrinsics definitions -----------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef LINALG3_INTRINSICS_H_
#define LINALG3_INTRINSICS_H_
#include "linalg2/Intrinsics.h"
#include "linalg3/Ops.h"
namespace linalg {
namespace intrinsics {
using load = mlir::edsc::intrinsics::ValueBuilder<LoadOp>;
using store = mlir::edsc::intrinsics::OperationBuilder<StoreOp>;
} // namespace intrinsics
} // namespace linalg
#endif // LINALG3_INTRINSICS_H_
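A minimal usage sketch (placeholder names vA, vB, vC, i, r): inside an active ScopedContext and with `using namespace mlir::edsc`, these intrinsics plug into TemplatedIndexedValue so that indexing emits linalg.load and assignment emits linalg.store, exactly as used in TensorOps.cpp below:

    using IndexedValue = TemplatedIndexedValue<linalg::intrinsics::load,
                                               linalg::intrinsics::store>;
    using edsc::op::operator+;
    using edsc::op::operator*;
    IndexedValue A(vA), B(vB), C(vC); // vA, vB, vC: Value* of ViewType (placeholders)
    ValueHandle c = *C(i);            // emits linalg.load %vC[%i]
    C(i) = c + A(i, r) * B(r);        // emits the remaining loads, mulf/addf, then linalg.store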

View File

@@ -0,0 +1,89 @@
//===- LoadStoreOps.h - Linalg dialect Load/Store operation definitions ---===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef LINALG3_LOADSTOREOP_H_
#define LINALG3_LOADSTOREOP_H_
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h"
namespace linalg {
class ViewType;
/// A linalg.LoadOp is the counterpart of affine.load but operating on ViewType
/// instead of MemRefType.
class LoadOp : public mlir::Op<LoadOp, mlir::OpTrait::VariadicOperands,
mlir::OpTrait::OneResult> {
public:
using Op::Op;
//////////////////////////////////////////////////////////////////////////////
// Hooks to customize the behavior of this op.
//////////////////////////////////////////////////////////////////////////////
static llvm::StringRef getOperationName() { return "linalg.load"; }
static void build(mlir::Builder *b, mlir::OperationState *result,
mlir::Value *view,
mlir::ArrayRef<mlir::Value *> indices = {});
mlir::LogicalResult verify();
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
void print(mlir::OpAsmPrinter *p);
//////////////////////////////////////////////////////////////////////////////
// Op-specific functionality.
//////////////////////////////////////////////////////////////////////////////
unsigned getRank();
ViewType getViewType();
mlir::Value *getView() { return getOperand(0); }
mlir::Operation::operand_range getIndices() {
return {operand_begin() + 1, operand_end()};
}
};
/// A linalg.StoreOp is the counterpart of affine.store but operating on
/// ViewType instead of MemRefType.
class StoreOp : public mlir::Op<StoreOp, mlir::OpTrait::VariadicOperands,
mlir::OpTrait::ZeroResult> {
public:
using Op::Op;
//////////////////////////////////////////////////////////////////////////////
// Hooks to customize the behavior of this op.
//////////////////////////////////////////////////////////////////////////////
static llvm::StringRef getOperationName() { return "linalg.store"; }
static void build(mlir::Builder *b, mlir::OperationState *result,
mlir::Value *valueToStore, mlir::Value *view,
mlir::ArrayRef<mlir::Value *> indices = {});
mlir::LogicalResult verify();
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
void print(mlir::OpAsmPrinter *p);
//////////////////////////////////////////////////////////////////////////////
// Op-specific functionality.
//////////////////////////////////////////////////////////////////////////////
unsigned getRank();
ViewType getViewType();
mlir::Value *getValueToStore() { return getOperand(0); }
mlir::Value *getView() { return getOperand(1); }
mlir::Operation::operand_range getIndices() {
return {operand_begin() + 2, operand_end()};
}
};
} // namespace linalg
#endif // LINALG3_LOADSTOREOP_H_

View File

@@ -19,6 +19,7 @@
#define LINALG3_OPS_H_
#include "linalg2/Ops.h"
#include "linalg3/LoadStoreOps.h"
#include "linalg3/TensorOps.h"
#endif // LINALG3_OPS_H_

View File

@@ -22,9 +22,10 @@
#define LINALG3_TENSOROPS_INL_H_
#include "linalg1/Common.h"
#include "linalg1/Utils.h"
#include "linalg2/TensorOps.h"
namespace linalg {
#include "linalg3/Analysis.h"
#include "linalg3/Ops.h"
template <class ConcreteOp>
mlir::Value *
@@ -38,6 +39,90 @@ linalg::TensorContractionBase<ConcreteOp>::getOutputView(unsigned i) {
return *(getOutputs().begin() + i);
}
} // namespace linalg
template <class ConcreteOp>
mlir::AffineMap
linalg::TensorContractionBase<ConcreteOp>::loopsToOperandRangesMap() {
return static_cast<ConcreteOp *>(this)->loopsToOperandRangesMap();
}
#endif // LINALG3_TENSOROPS-INL_H_
template <class ConcreteOp>
void linalg::TensorContractionBase<ConcreteOp>::emitScalarImplementation(
llvm::ArrayRef<mlir::Value *> parallelIvs,
llvm::ArrayRef<mlir::Value *> reductionIvs) {
static_cast<ConcreteOp *>(this)->emitScalarImplementation(parallelIvs,
reductionIvs);
}
template <class ConcreteOp>
mlir::AffineMap linalg::operandRangesToLoopsMap(
linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
return inverseSubMap(tensorContraction.loopsToOperandRangesMap());
}
// Extract the ranges from a given ViewOp or SliceOp.
//
// In the case of a ViewOp, things are simple: just traverse the indexings and
// get all the ranges (i.e. drop the indices).
//
// In the case of a SliceOp, things are trickier because we need to handle a
// potential rank-reduction:
// 1. Examine the indexing to determine if it is rank-reducing.
// 2. If it is rank-reducing, an offset of 1 is added to the dimensions such
// that `d >= slicingDim`. This is to account for the rank reduction.
// `getViewRootIndexing` is then called on the **parent** view.
static llvm::SmallVector<mlir::Value *, 8>
extractRangesFromViewOrSliceOp(mlir::Value *view) {
// This expects a viewType which must come from either ViewOp or SliceOp.
assert(view->getType().isa<linalg::ViewType>() && "expected ViewType");
if (auto viewOp = view->getDefiningOp()->dyn_cast<linalg::ViewOp>())
return viewOp.getRanges();
auto sliceOp = view->getDefiningOp()->cast<linalg::SliceOp>();
unsigned slicingDim = sliceOp.getSlicingDim();
auto *indexing = *(sliceOp.getIndexings().begin());
bool isRankReducing = indexing->getType().isa<mlir::IndexType>();
unsigned offset = 0;
llvm::SmallVector<mlir::Value *, 8> res;
res.reserve(sliceOp.getRank());
for (unsigned d = 0, e = sliceOp.getRank(); d < e; ++d) {
if (d == slicingDim && isRankReducing)
offset = 1;
auto *parentView = sliceOp.getParentView();
auto indexingPosPair = linalg::getViewRootIndexing(parentView, d + offset);
res.push_back(indexingPosPair.first);
}
return res;
}
template <class ConcreteOp>
static llvm::SmallVector<mlir::Value *, 8>
getInputRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
llvm::SmallVector<mlir::Value *, 8> res;
for (auto *in : tensorContraction.getInputs()) {
auto subres = extractRangesFromViewOrSliceOp(in);
res.append(subres.begin(), subres.end());
}
return res;
}
template <class ConcreteOp>
static llvm::SmallVector<mlir::Value *, 8>
getOutputRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
llvm::SmallVector<mlir::Value *, 8> res;
for (auto *out : tensorContraction.getOutputs()) {
auto subres = extractRangesFromViewOrSliceOp(out);
res.append(subres.begin(), subres.end());
}
return res;
}
template <class ConcreteOp>
llvm::SmallVector<mlir::Value *, 8> linalg::getRanges(
linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
llvm::SmallVector<mlir::Value *, 8> res = getInputRanges(tensorContraction);
llvm::SmallVector<mlir::Value *, 8> tmp = getOutputRanges(tensorContraction);
res.append(tmp.begin(), tmp.end());
return res;
}
#endif // LINALG3_TENSOROPS_INL_H_
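A worked example of the rank-reducing case handled above, using the slice forms printed by the tests (value names are illustrative):

    // %v = linalg.view  %A[%rM, %rK]          : !linalg<"view<f32xf32>">  // ranges (%rM, %rK)
    // %s = linalg.slice %v[%i, *] { dim : 0 } : !linalg<"view<f32>">      // rank-reducing index
    //
    // For %s, slicingDim = 0 and the indexing %i has IndexType, so the slice
    // is rank-reducing. At d = 0 the offset becomes 1 and getViewRootIndexing
    // is queried on the parent view %v at position d + 1, yielding (%rK):
    // the sliced dimension disappears and only the surviving range is kept.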

View File

@@ -20,6 +20,32 @@
#include "linalg2/TensorOps.h"
namespace linalg {
///
/// Ideally all these functions would go in an Analysis but, as long as
/// TensorContractionBase is templated, they need to remain close to the ops.
///
/// Takes a `tensorContraction` and returns an AffineMap that maps all the
/// operands' ranges to the enclosing loops.
template <class ConcreteOp>
mlir::AffineMap operandRangesToLoopsMap(
linalg::TensorContractionBase<ConcreteOp> &tensorContraction);
/// Takes a `tensorContraction` and returns the ranges of all its operands.
/// When an operand comes from a ViewOp, things are simple:
/// just traverse the indexings and get all the ranges
/// (i.e. drop the rank-reducing indices).
/// In the case of a SliceOp, things are more involved because we need to handle
/// potential rank-reductions.
/// This function abstracts this complexity away and returns all the ranges.
template <class ConcreteOp>
llvm::SmallVector<mlir::Value *, 8>
getRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction);
} // namespace linalg
/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of
/// TensorOps by adding implementations as they are needed in the appropriate
/// step in the tutorial.

View File

@@ -30,10 +30,17 @@ namespace linalg {
/// to only use linalg.view operations.
void composeSliceOps(mlir::Function *f);
/// Traverses `f` and rewrites linalg.matmul (resp. linalg.matvec, linalg.dot)
/// as linalg.matvec (resp. linalg.dot, loop form).
/// Traverses `f` and rewrites linalg.load and linalg.store to affine.load and
/// affine.store operations.
void lowerLinalgLoadStores(mlir::Function *f);
/// Traverses `f` and rewrites linalg.matmul (resp. linalg.matvec)
/// as linalg.matvec (resp. linalg.dot).
void lowerToFinerGrainedTensorContraction(mlir::Function *f);
/// Traverses `f` and rewrites linalg operations in loop form.
void lowerToLoops(mlir::Function *f);
} // namespace linalg
#endif // LINALG3_TRANSFORMS_H_
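A minimal usage sketch mirroring the matmul_as_matvec_as_loops test above (f is assumed to hold a single linalg.matmul):

    // Rewrite matmul as a loop over matvec, lower the remaining matvec to
    // affine loops with scalar linalg.load/linalg.store, then fold the slices.
    linalg::lowerToFinerGrainedTensorContraction(f);
    linalg::lowerToLoops(f);
    linalg::composeSliceOps(f);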

View File

@@ -0,0 +1,62 @@
//===- Analysis.cpp - Implementation of analysis functions for Linalg -----===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements analysis functions for the linalg dialect, in
// particular the inversion of permutation AffineMaps.
//
//===----------------------------------------------------------------------===//
#include "linalg3/Analysis.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/StandardTypes.h"
using llvm::SmallVector;
using namespace mlir;
// Compute an inverse map (only works with permutations for now).
// Note that a dim may appear in several results, so this keeps the first
// seen entry for each dim.
static AffineMap inversePermutationMap(AffineMap map) {
SmallVector<AffineExpr, 4> exprs(map.getNumDims());
for (auto en : llvm::enumerate(map.getResults())) {
auto expr = en.value();
auto d = expr.dyn_cast<AffineDimExpr>();
assert(d && "permutation map expected");
if (exprs[d.getPosition()])
continue;
exprs[d.getPosition()] = getAffineDimExpr(en.index(), d.getContext());
}
SmallVector<AffineExpr, 4> seenExprs;
seenExprs.reserve(map.getNumDims());
for (auto expr : exprs)
if (expr)
seenExprs.push_back(expr);
assert(map.getNumSymbols() == 0 && "expected map without symbols");
assert(seenExprs.size() == map.getNumInputs() && "map is not invertible");
return AffineMap::get(map.getNumResults(), 0, seenExprs, {});
}
mlir::AffineMap linalg::inverseSubMap(AffineMap map, unsigned beginResult,
unsigned endResult) {
if (beginResult == 0 && endResult == 0)
endResult = map.getNumResults();
auto subMap = AffineMap::get(
map.getNumDims(), map.getNumSymbols(),
map.getResults().slice(beginResult, endResult - beginResult), {});
return inversePermutationMap(subMap);
}
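As a worked example (a sketch, using the matvec map defined in TensorOps.cpp below): for (d0, d1) -> (d0, d1, d1, d0), the results enumerate the ranges of A(M, K), B(K), C(M). inversePermutationMap keeps the first result position seen for each dim, so inverseSubMap over the whole map returns:

    //   (d0, d1, d2, d3) -> (d0, d1)
    //
    // i.e. the M loop range comes from operand range 0 (A's first range) and
    // the K loop range from operand range 1 (A's second range); the duplicate
    // occurrences at result positions 2 and 3 are ignored.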

View File

@@ -0,0 +1,39 @@
//===- DialectRegistration.cpp - Registration of the Linalg dialect -------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file registers the Linalg dialect and should live in a standalone
// library. Linking with this library will create a static global object that
// performs dialect registration.
//
//===----------------------------------------------------------------------===//
#include "linalg1/Dialect.h"
#include "linalg1/Types.h"
#include "linalg3/Ops.h"
using namespace linalg;
LinalgDialect::LinalgDialect(mlir::MLIRContext *context)
: Dialect("linalg", context) {
addTypes<RangeType, ViewType>();
addOperations<DotOp, LoadOp, MatvecOp, MatmulOp, RangeOp, SliceOp, StoreOp,
ViewOp>();
}
// Dialect registration triggers the creation of a `LinalgDialect` object which
// adds the proper types and operations to the dialect.
static mlir::DialectRegistration<LinalgDialect> LinalgOps;

View File

@@ -0,0 +1,136 @@
//===- LoadStoreOps.cpp - Implementation of linalg Load/Store operations --===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements linalg.load and linalg.store operations which allow
// accessing memory through ViewType values.
//
//===----------------------------------------------------------------------===//
#include "linalg3/LoadStoreOps.h"
#include "linalg3/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
using llvm::ArrayRef;
using namespace mlir;
using namespace linalg;
////////////////////////////////////////////////////////////////////////////////
// LoadOp.
////////////////////////////////////////////////////////////////////////////////
void linalg::LoadOp::build(Builder *b, OperationState *result, Value *view,
ArrayRef<Value *> indices) {
auto viewType = view->getType().cast<ViewType>();
result->addOperands(view);
result->addOperands(indices);
result->addTypes(viewType.getElementType());
}
void linalg::LoadOp::print(OpAsmPrinter *p) {
*p << getOperationName() << " " << *getView() << '[';
p->printOperands(getIndices());
*p << ']';
p->printOptionalAttrDict(getAttrs());
*p << " : " << getViewType();
}
bool linalg::LoadOp::parse(OpAsmParser *parser, OperationState *result) {
llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
return false;
}
LogicalResult linalg::LoadOp::verify() {
if (getNumOperands() == 0)
return emitOpError("expected a view to load from");
auto viewType = getView()->getType().dyn_cast<ViewType>();
if (!viewType)
return emitOpError("first operand must be a view");
if (getType() != viewType.getElementType())
return emitOpError("result type must match element type of the view");
if (getRank() != getNumOperands() - 1)
return emitOpError("incorrect number of indices for load");
for (auto *idx : getIndices())
if (!idx->getType().isIndex())
return emitOpError("index to load must have 'index' type");
return success();
}
ViewType linalg::LoadOp::getViewType() {
return getView()->getType().cast<ViewType>();
}
unsigned linalg::LoadOp::getRank() { return getViewType().getRank(); }
////////////////////////////////////////////////////////////////////////////////
// StoreOp.
////////////////////////////////////////////////////////////////////////////////
void linalg::StoreOp::build(Builder *b, OperationState *result,
Value *valueToStore, Value *view,
ArrayRef<Value *> indices) {
result->addOperands(valueToStore);
result->addOperands(view);
result->addOperands(indices);
}
void linalg::StoreOp::print(OpAsmPrinter *p) {
*p << getOperationName() << " " << *getValueToStore();
*p << ", " << *getView() << '[';
p->printOperands(getIndices());
*p << ']';
p->printOptionalAttrDict(getAttrs());
*p << " : " << getViewType();
}
bool linalg::StoreOp::parse(OpAsmParser *parser, OperationState *result) {
assert(false && "NYI");
return false;
}
LogicalResult linalg::StoreOp::verify() {
if (getNumOperands() < 2)
return emitOpError("expected a value to store and a view");
// Second operand must be a view type.
auto viewType = getView()->getType().dyn_cast<ViewType>();
if (!viewType)
return emitOpError("second operand must be a view");
// First operand must have the same type as the view's element type.
if (getValueToStore()->getType() != viewType.getElementType())
return emitOpError("first operand must have same element type as the view");
if (getNumOperands() != 2 + viewType.getRank())
return emitOpError("store index operand count not equal to view rank");
for (auto *idx : getIndices())
if (!idx->getType().isIndex())
return emitOpError("index to store must have 'index' type");
return success();
}
unsigned linalg::StoreOp::getRank() { return getViewType().getRank(); }
ViewType linalg::StoreOp::getViewType() {
return getView()->getType().cast<ViewType>();
}
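For reference, the custom assembly form these print methods produce, modeled on the FileCheck patterns in Example.cpp above (value names are illustrative):

    //   %0 = linalg.load %view[%i, %j] : !linalg<"view<f32xf32>">
    //   linalg.store %val, %view[%i, %j] : !linalg<"view<f32xf32>">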

View File

@@ -22,7 +22,7 @@
#include "linalg1/Analysis.h"
#include "linalg1/Common.h"
#include "linalg2/Intrinsics.h"
#include "linalg3/Intrinsics.h"
#include "linalg3/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpDefinition.h"
@@ -32,23 +32,121 @@
using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace linalg;
using namespace linalg::intrinsics;
//////////////////////////////////////////////////////////////////////////////
// Implementation of DotOp.
//////////////////////////////////////////////////////////////////////////////
AffineMap linalg::DotOp::loopsToOperandRangesMap() {
// A(K), B(K), C()
assert(getRanges(*this).size() == 2);
auto *context = ScopedContext::getContext();
auto d0 = getAffineDimExpr(0, context); // K
// A(K), B(K), C()
// (d0) -> (d0, d0)(%k)
return AffineMap::get(1, 0, {d0, d0}, {});
}
void linalg::DotOp::emitScalarImplementation(
llvm::ArrayRef<Value *> parallelIvs, llvm::ArrayRef<Value *> reductionIvs) {
using IndexedValue = TemplatedIndexedValue<linalg::intrinsics::load,
linalg::intrinsics::store>;
assert(reductionIvs.size() == 1);
auto innermostLoop = getForInductionVarOwner(reductionIvs.back());
auto *body = innermostLoop.getBody();
using edsc::op::operator+;
using edsc::op::operator*;
using edsc::op::operator==;
using edsc::intrinsics::select;
ScopedContext scope( // account for affine.terminator in loop.
FuncBuilder(body, std::prev(body->end(), 1)), innermostLoop.getLoc());
auto f32 = ScopedContext::getBuilder()->getF32Type();
IndexHandle zero(constant_index(0));
ValueHandle zerof = constant_float(llvm::APFloat(0.0f), f32);
IndexHandle r_i(reductionIvs[0]);
IndexedValue A(getOperand(0)), B(getOperand(1)), C(getOperand(2));
ValueHandle cond = (r_i == zero);
ValueHandle scalarC = select(cond, zerof, *C());
C() = scalarC + A(r_i) * B(r_i);
}
//////////////////////////////////////////////////////////////////////////////
// Implementation of MatvecOp.
//////////////////////////////////////////////////////////////////////////////
AffineMap linalg::MatvecOp::loopsToOperandRangesMap() {
// A(M, K), B(K), C(M)
assert(getRanges(*this).size() == 4);
auto *context = ScopedContext::getContext();
auto d0 = getAffineDimExpr(0, context); // M
auto d1 = getAffineDimExpr(1, context); // K
// A(M, K), B(K), C(M)
// (d0, d1) -> (d0, d1, d1, d0)(%m, %k)
return AffineMap::get(2, 0, {d0, d1, d1, d0}, {});
}
// The body expression for matvec is: C(i) = scalarC + A(i, r_j) * B(r_j)
// The body expression for dot is: C() = A(r_i) * B(r_i);
// So we must drop the `i` loop from the matvec.
void linalg::MatvecOp::writeAsFinerGrainTensorContraction() {
auto *op = getOperation();
ScopedContext scope(FuncBuilder(op), op->getLoc());
IndexHandle i;
auto *vA(getInputView(0)), *vB(getInputView(1)), *vC(getOutputView(0));
auto indexingPosPair = getViewRootIndexing(vA, 0);
assert(indexingPosPair.first->getDefiningOp() &&
indexingPosPair.first->getDefiningOp()->isa<RangeOp>());
linalg::common::LoopNestRangeBuilder(&i, ValueHandle(indexingPosPair.first))({
dot(slice(vA, i, 0), vB, slice(vC, i, 0)),
// clang-format off
ScopedContext scope(FuncBuilder(op), op->getLoc());
IndexHandle i;
using linalg::common::LoopNestRangeBuilder;
LoopNestRangeBuilder(&i, ValueHandle(indexingPosPair.first))({
[&i, &vA, &vB, &vC]() {
ValueHandle sliceA = slice(vA, i, 0);
ValueHandle sliceC = slice(vC, i, 0);
dot(sliceA, vB, sliceC);
/// NestedBuilders expect handles, we thus return an IndexHandle.
return IndexHandle();
}()
});
// clang-format on
}
void linalg::MatvecOp::emitScalarImplementation(
llvm::ArrayRef<Value *> parallelIvs, llvm::ArrayRef<Value *> reductionIvs) {
using IndexedValue = TemplatedIndexedValue<linalg::intrinsics::load,
linalg::intrinsics::store>;
assert(reductionIvs.size() == 1);
auto innermostLoop = getForInductionVarOwner(reductionIvs.back());
auto *body = innermostLoop.getBody();
using edsc::op::operator+;
using edsc::op::operator*;
using edsc::op::operator==;
using edsc::intrinsics::select;
ScopedContext scope( // account for affine.terminator in loop.
FuncBuilder(body, std::prev(body->end(), 1)), innermostLoop.getLoc());
auto f32 = ScopedContext::getBuilder()->getF32Type();
IndexHandle i(parallelIvs[0]), r_j(reductionIvs[0]);
IndexedValue A(getOperand(0)), B(getOperand(1)), C(getOperand(2));
IndexHandle zero(constant_index(0));
ValueHandle zerof = constant_float(llvm::APFloat(0.0f), f32);
ValueHandle cond = (r_j == zero);
ValueHandle scalarC = select(cond, zerof, *C(i));
C(i) = scalarC + A(i, r_j) * B(r_j);
}
//////////////////////////////////////////////////////////////////////////////
// Implementation of MatmulOp.
//////////////////////////////////////////////////////////////////////////////
AffineMap linalg::MatmulOp::loopsToOperandRangesMap() {
// A(M, K), B(K, N), C(M, N)
assert(getRanges(*this).size() == 6);
auto *context = ScopedContext::getContext();
auto d0 = getAffineDimExpr(0, context); // M
auto d1 = getAffineDimExpr(1, context); // N
auto d2 = getAffineDimExpr(2, context); // K
// A(M, K), B(K, N), C(M, N):
// (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1)(%m, %n, %k)
return AffineMap::get(3, 0, {d0, d2, d2, d1, d0, d1}, {});
}
// The body expression for matmul is: C(i, j) = scalarC + A(i, r_k) * B(r_k, j)
@@ -58,13 +156,45 @@ void linalg::MatvecOp::writeAsFinerGrainTensorContraction() {
// declaratively.
void linalg::MatmulOp::writeAsFinerGrainTensorContraction() {
auto *op = getOperation();
ScopedContext scope(FuncBuilder(op), op->getLoc());
IndexHandle j;
auto *vA(getInputView(0)), *vB(getInputView(1)), *vC(getOutputView(0));
auto indexingPosPair = getViewRootIndexing(vB, 1);
assert(indexingPosPair.first->getDefiningOp() &&
indexingPosPair.first->getDefiningOp()->isa<RangeOp>());
linalg::common::LoopNestRangeBuilder(&j, ValueHandle(indexingPosPair.first))({
matvec(vA, slice(vB, j, 1), slice(vC, j, 1)),
using linalg::common::LoopNestRangeBuilder;
// clang-format off
ScopedContext scope(FuncBuilder(op), op->getLoc());
IndexHandle j;
LoopNestRangeBuilder(&j, ValueHandle(indexingPosPair.first))({
[&j, &vA, &vB, &vC]() {
ValueHandle sliceB = slice(vB, j, 1);
ValueHandle sliceC = slice(vC, j, 1);
matvec(vA, sliceB, sliceC);
/// NestedBuilders expect handles, we thus return an IndexHandle.
return IndexHandle();
}()
});
// clang-format on
}
void linalg::MatmulOp::emitScalarImplementation(
llvm::ArrayRef<Value *> parallelIvs, llvm::ArrayRef<Value *> reductionIvs) {
using IndexedValue = TemplatedIndexedValue<linalg::intrinsics::load,
linalg::intrinsics::store>;
assert(reductionIvs.size() == 1);
auto innermostLoop = getForInductionVarOwner(reductionIvs.back());
auto *body = innermostLoop.getBody();
using edsc::op::operator+;
using edsc::op::operator*;
using edsc::op::operator==;
using edsc::intrinsics::select;
ScopedContext scope( // account for affine.terminator in loop.
FuncBuilder(body, std::prev(body->end(), 1)), innermostLoop.getLoc());
auto f32 = ScopedContext::getBuilder()->getF32Type();
IndexHandle i(parallelIvs[0]), j(parallelIvs[1]), r_k(reductionIvs[0]);
IndexedValue A(getOperand(0)), B(getOperand(1)), C(getOperand(2));
IndexHandle zero(constant_index(0));
ValueHandle zerof = constant_float(llvm::APFloat(0.0f), f32);
ValueHandle cond = r_k == zero;
ValueHandle scalarC = select(cond, zerof, *C(i, j));
C(i, j) = scalarC + A(i, r_k) * B(r_k, j);
}
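The single EDSC assignment above, C(i, j) = scalarC + A(i, r_k) * B(r_k, j), expands to the scalar IR checked by the matmul_as_loops test; roughly (illustrative value names, following the test's FileCheck lines):

    //   %cond = cmpi "eq", %i2, %c0 : index
    //   %c    = linalg.load %vC[%i0, %i1] : !linalg<"view<f32xf32>">
    //   %cc   = select %cond, %zerof, %c : f32
    //   %b    = linalg.load %vB[%i2, %i1] : !linalg<"view<f32xf32>">
    //   %a    = linalg.load %vA[%i0, %i2] : !linalg<"view<f32xf32>">
    //   %prod = mulf %a, %b : f32
    //   %acc  = addf %cc, %prod : f32
    //   linalg.store %acc, %vC[%i0, %i1] : !linalg<"view<f32xf32>">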

View File

@@ -53,3 +53,171 @@ void linalg::lowerToFinerGrainedTensorContraction(mlir::Function *f) {
op->erase();
});
}
// Folding eagerly is necessary to abide by the affine.for static step requirement.
// Returns nullptr if folding is not trivially feasible.
static Value *tryFold(AffineMap map, SmallVector<Value *, 4> operands) {
assert(map.getNumResults() == 1 && "single result map expected");
auto expr = map.getResult(0);
if (auto dim = expr.dyn_cast<AffineDimExpr>())
return operands[dim.getPosition()];
if (auto sym = expr.dyn_cast<AffineSymbolExpr>())
return operands[map.getNumDims() + sym.getPosition()];
if (auto cst = expr.dyn_cast<AffineConstantExpr>())
return constant_index(cst.getValue());
return nullptr;
}
static Value *makeFoldedComposedAffineApply(AffineMap map,
ArrayRef<Value *> operandsRef) {
SmallVector<Value *, 4> operands(operandsRef.begin(), operandsRef.end());
fullyComposeAffineMapAndOperands(&map, &operands);
if (auto *v = tryFold(map, operands)) {
return v;
}
auto *b = ScopedContext::getBuilder();
auto loc = ScopedContext::getLocation();
return b->create<AffineApplyOp>(loc, map, operands).getResult();
}
struct RangeParts {
explicit RangeParts(unsigned reserved);
RangeParts(ArrayRef<Value *> ranges);
SmallVector<Value *, 4> makeRanges();
SmallVector<Value *, 4> mins;
SmallVector<Value *, 4> maxes;
SmallVector<Value *, 4> steps;
};
RangeParts::RangeParts(unsigned reserved) {
mins.reserve(reserved);
maxes.reserve(reserved);
steps.reserve(reserved);
}
static SmallVector<Value *, 4>
extractFromRanges(ArrayRef<Value *> ranges,
std::function<Value *(RangeOp)> extract) {
SmallVector<Value *, 4> res;
res.reserve(ranges.size());
for (auto *v : ranges) {
auto r = v->getDefiningOp()->cast<RangeOp>();
res.push_back(extract(r));
}
return res;
}
RangeParts::RangeParts(ArrayRef<Value *> ranges)
: mins(extractFromRanges(ranges, [](RangeOp r) { return r.getMin(); })),
maxes(extractFromRanges(ranges, [](RangeOp r) { return r.getMax(); })),
steps(extractFromRanges(ranges, [](RangeOp r) { return r.getStep(); })) {}
SmallVector<Value *, 4> RangeParts::makeRanges() {
SmallVector<Value *, 4> res;
res.reserve(mins.size());
for (auto z : llvm::zip(mins, maxes, steps)) {
res.push_back(range(std::get<0>(z), std::get<1>(z), std::get<2>(z)));
}
return res;
}
static RangeParts makeGenericRangeParts(AffineMap map,
ArrayRef<Value *> ranges) {
assert(map.getNumInputs() == ranges.size());
unsigned numDims = map.getNumDims();
assert(map.getNumSymbols() == 0);
assert(map.getRangeSizes().empty());
RangeParts res(map.getNumResults());
RangeParts rangeParts(ranges);
for (auto expr : map.getResults()) {
AffineMap map = AffineMap::get(numDims, 0, expr, {});
res.mins.push_back(makeFoldedComposedAffineApply(map, rangeParts.mins));
res.maxes.push_back(makeFoldedComposedAffineApply(map, rangeParts.maxes));
res.steps.push_back(makeFoldedComposedAffineApply(map, rangeParts.steps));
}
return res;
}
SmallVector<Value *, 4> makeGenericRanges(AffineMap map,
ArrayRef<Value *> ranges) {
return makeGenericRangeParts(map, ranges).makeRanges();
}
static SmallVector<Value *, 4> makeGenericLoopRanges(
AffineMap operandRangesToLoopsMap, ArrayRef<Value *> ranges,
llvm::Optional<ArrayRef<Value *>> tileSizes = llvm::None) {
RangeParts res = makeGenericRangeParts(operandRangesToLoopsMap, ranges);
if (!tileSizes.hasValue())
return res.makeRanges();
SmallVector<Value *, 4> tiledSteps;
for (auto z : llvm::zip(res.steps, *tileSizes)) {
auto *step = std::get<0>(z);
auto tileSize = std::get<1>(z);
auto stepValue = step->getDefiningOp()->cast<ConstantIndexOp>().getValue();
auto tileSizeValue =
tileSize->getDefiningOp()->cast<ConstantIndexOp>().getValue();
assert(stepValue > 0);
tiledSteps.push_back(constant_index(stepValue * tileSizeValue));
}
res.steps = tiledSteps;
return res.makeRanges();
}
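The optional tileSizes argument is a hook for later tiling work: when constant tile sizes are passed, only the steps change. A small arithmetic sketch:

    // With a loop range whose step is constant 1 and a tile size of constant 8,
    // the min and max are kept and the returned range uses step 1 * 8 = 8,
    // materialized as constant_index(8).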
template <class ContractionOp>
static SmallVector<mlir::AffineForOp, 4>
writeAsLoops(ContractionOp contraction) {
ScopedContext scope(mlir::FuncBuilder(contraction.getOperation()),
contraction.getLoc());
auto loopRanges = makeGenericLoopRanges(operandRangesToLoopsMap(contraction),
getRanges(contraction));
SmallVector<IndexHandle, 4> parallelIvs(contraction.getNumParallelDims());
SmallVector<IndexHandle, 4> reductionIvs(contraction.getNumReductionDims());
auto pivs = IndexHandle::makeIndexHandlePointers(parallelIvs);
auto rivs = IndexHandle::makeIndexHandlePointers(reductionIvs);
assert(loopRanges.size() == pivs.size() + rivs.size());
// clang-format off
using linalg::common::LoopNestRangeBuilder;
ArrayRef<Value *> ranges(loopRanges);
LoopNestRangeBuilder(pivs, ranges.take_front(pivs.size()))({
LoopNestRangeBuilder(rivs, ranges.take_back(rivs.size()))({
[&contraction, &parallelIvs, &reductionIvs]() {
SmallVector<mlir::Value *, 4> parallel(
parallelIvs.begin(), parallelIvs.end());
SmallVector<mlir::Value *, 4> reduction(
reductionIvs.begin(), reductionIvs.end());
contraction.emitScalarImplementation(parallel, reduction);
/// NestedBuilders expect handles, we thus return an IndexHandle.
return IndexHandle();
}()
})
});
// clang-format on
SmallVector<mlir::AffineForOp, 4> res;
res.reserve(pivs.size() + rivs.size());
for (auto iv : parallelIvs)
res.push_back(getForInductionVarOwner(iv.getValue()));
for (auto iv : reductionIvs)
res.push_back(getForInductionVarOwner(iv.getValue()));
return res;
}
void linalg::lowerToLoops(mlir::Function *f) {
f->walkPostOrder([](Operation *op) {
if (auto matmulOp = op->dyn_cast<linalg::MatmulOp>()) {
writeAsLoops(matmulOp);
} else if (auto matvecOp = op->dyn_cast<linalg::MatvecOp>()) {
writeAsLoops(matvecOp);
} else if (auto dotOp = op->dyn_cast<linalg::DotOp>()) {
writeAsLoops(dotOp);
} else {
return;
}
op->erase();
});
}