forked from OSchip/llvm-project
Linalg portion of the tutorial - part 3-2
This CL adds support for lowering tensor contractions to loops declaratively. This is done thanks to two properties of such operations:
1. the definition of an AffineMap `loopsToOperandRangesMap` for each op, which maps iteration space dimensions to the ranges of the view operands, in their order of occurrence;
2. the definition of a scalar implementation for each op, which emits the computation inside the loops given the enclosing parallel and reduction loops.
All the other properties are derived in a generic fashion from these two properties and a few analyses. A lowerToLoops transformation is added, as well as a test that exercises it.

PiperOrigin-RevId: 241783992

parent b9e3b2107b
commit 50df91745d
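As a minimal sketch of what these two properties look like for matmul (both pieces are taken from the TensorOps.cpp changes below; the surrounding ScopedContext/builder setup is assumed):

// Property 1: loops (m, n, k) map to the operand ranges of A(M, K), B(K, N), C(M, N).
AffineMap linalg::MatmulOp::loopsToOperandRangesMap() {
  auto *context = ScopedContext::getContext();
  auto d0 = getAffineDimExpr(0, context); // M
  auto d1 = getAffineDimExpr(1, context); // N
  auto d2 = getAffineDimExpr(2, context); // K
  // (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1)
  return AffineMap::get(3, 0, {d0, d2, d2, d1, d0, d1}, {});
}

// Property 2: the scalar body emitted inside the innermost reduction loop, in
// compact index notation:
//   C[i, j] = select(r_k == 0, 0.0f, C[i, j]) + A[i, r_k] * B[r_k, j]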
@@ -41,6 +41,7 @@ protected:
  static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
  void print(mlir::OpAsmPrinter *p);

public:
  //////////////////////////////////////////////////////////////////////////////
  // Op-specific functionality.
  //////////////////////////////////////////////////////////////////////////////
@@ -48,7 +49,6 @@ protected:
  mlir::Operation::operand_range getInputs();
  mlir::Operation::operand_range getOutputs();

public:
  /// These are better as methods calling into the ConcreteOp instead of
  /// template parameters because methods allow more generic behavior and avoid
  /// specializing for number of arguments. All derived classes have
@@ -72,14 +72,24 @@ public:
  //////////////////////////////////////////////////////////////////////////////
  mlir::Value *getInputView(unsigned i);
  mlir::Value *getOutputView(unsigned i);
  /// Computes a mapping from all the ranges of the operands to the enclosing
  /// loops. In order to support "broadcast"-style semantics, we need to
  /// consider all the operands (i.e. input operands are not sufficient).

  /// Each op is responsible for declaring how it lowers itself to scalar form,
  /// given the enclosing parallel and reduction induction variables.
  /// `emitScalarImplementation` emits the scalar IR for the op in the nesting
  /// context of the innermost enclosing loop (i.e. `reductionIvs.back()` or
  /// `parallelIvs.back()`).
  void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
                                llvm::ArrayRef<mlir::Value *> reductionIvs);

  /// Represents a mapping from the loops to all the ranges of the operands.
  /// The operands and their ranges are in the order defined by the particular
  /// ConcreteOp implementation; the resulting map must match that order.
  /// This is currently computed but can also be specified explicitly in each
  /// operator to generalize to cases where an analysis is not available.
  mlir::AffineMap operandRangesToLoopsMap();
  /// In favorable cases, this can be calculated by an analysis but specifying
  /// it explicitly is not expensive and generalizes to cases where an analysis
  /// is not available.
  /// For details, see the description of loopsToOperandRangesMap in each
  /// ConcreteOp.
  mlir::AffineMap loopsToOperandRangesMap();
};

/// Implements c = A * B where c is a scalar and A and B are 1-D vectors.
@@ -119,6 +129,27 @@ public:
  /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
  /// loop over matvec). Does nothing by default.
  void writeAsFinerGrainTensorContraction();

  /// Inputs to this map will be (%k) coming from enclosing loops.
  /// Therefore, the mapping to get back to A(K), B(K), C() is:
  ///   (d0) -> (d0, d0)(%k)
  /// And the operands ranges are:
  ///   (%k, %k)
  mlir::AffineMap loopsToOperandRangesMap();

  /// Given an enclosing reduction loop with iv `r_i`, emits MLIR corresponding
  /// to:
  ///   1. conditionally assign scalarC to 0.0f on the first iteration or load
  ///      C[] from memory (0-D tensor)
  ///   2. multiply A[r_i] by B[r_i] and add to scalarC
  ///   3. store back scalarC at C[]
  ///
  /// In some compact index notation this could be written:
  ///   cond = (r_i == zero)
  ///   scalarC = select(cond, zerof, C[]);
  ///   C[] = scalarC + A[r_i] * B[r_i];
  void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
                                llvm::ArrayRef<mlir::Value *> reductionIvs);
};

/// Implements C = A * B where A is a 2-D matrix and B and C are 1-D vectors.
@@ -158,6 +189,27 @@ public:
  /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
  /// loop over matvec). Does nothing by default.
  void writeAsFinerGrainTensorContraction();

  /// Inputs to this map will be (%m, %k) coming from enclosing loops.
  /// Therefore, the mapping to get back to A(M, K), B(K), C(M) is:
  ///   (d0, d1) -> (d0, d1, d1, d0)(%m, %k)
  /// And the operands ranges are:
  ///   (%m, %k, %k, %m)
  mlir::AffineMap loopsToOperandRangesMap();

  /// Given an enclosing parallel loop with iv `i` and an enclosing reduction
  /// loop with iv `r_j`, emits MLIR corresponding to:
  ///   1. conditionally assign scalarC to 0.0f on the first iteration or load
  ///      C[i]
  ///   2. multiply A[i, r_j] by B[r_j] and add to scalarC
  ///   3. store back scalarC at C[i]
  ///
  /// In some compact index notation this could be written:
  ///   cond = (r_j == zero)
  ///   scalarC = select(cond, zerof, C(i));
  ///   C(i) = scalarC + A(i, r_j) * B(r_j);
  void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
                                llvm::ArrayRef<mlir::Value *> reductionIvs);
};

/// Implements C = A * B on 2-D matrices.
@@ -197,6 +249,27 @@ public:
  /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
  /// loop over matvec). Does nothing by default.
  void writeAsFinerGrainTensorContraction();

  /// Inputs to this map will be (%m, %n, %k) coming from enclosing loops.
  /// Therefore, the mapping to get back to A(M, K), B(K, N), C(M, N) is:
  ///   (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1)(%m, %n, %k)
  /// And the operands ranges are:
  ///   (%m, %k, %k, %n, %m, %n)
  mlir::AffineMap loopsToOperandRangesMap();

  /// Given enclosing parallel loops with ivs `i` and `j`, and an enclosing
  /// reduction loop with iv `r_k`, emits MLIR corresponding to:
  ///   1. conditionally assign scalarC to 0.0f on the first iteration or load
  ///      C[i, j]
  ///   2. multiply A[i, r_k] by B[r_k, j] and add to scalarC
  ///   3. store back scalarC at C[i, j]
  ///
  /// In some compact index notation this could be written:
  ///   cond = (r_k == zero)
  ///   scalarC = select(cond, zerof, C[i, j]);
  ///   C[i, j] = scalarC + A[i, r_k] * B[r_k, j];
  void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
                                llvm::ArrayRef<mlir::Value *> reductionIvs);
};

} // namespace linalg
@@ -64,13 +64,15 @@ TEST_FUNC(matmul_as_matvec) {
  Module module(&context);
  mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec");
  lowerToFinerGrainedTensorContraction(f);
  composeSliceOps(f);
  // clang-format off
  // CHECK-LABEL: func @matmul_as_matvec(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
  // CHECK: %[[N:.*]] = dim %arg2, 1 : memref<?x?xf32>
  // CHECK: %[[vA:.*]] = linalg.view %arg0[%{{.*}}, %{{.*}}] : !linalg<"view<f32xf32>">
  // CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[N]]) {
  // CHECK-NEXT: %[[vB:.*]] = linalg.slice %{{.*}}[*, %i0] { dim : 1 } : !linalg<"view<f32>">
  // CHECK-NEXT: %[[vC:.*]] = linalg.slice %{{.*}}[*, %i0] { dim : 1 } : !linalg<"view<f32>">
  // CHECK-NEXT: linalg.matvec {%{{.*}}, %[[vB]]} -> {%[[vC]]}
  // CHECK: %[[vB:.*]] = linalg.view %arg1[%{{.*}}, %{{.*}}] : !linalg<"view<f32>">
  // CHECK: %[[vC:.*]] = linalg.view %arg2[%{{.*}}, %{{.*}}] : !linalg<"view<f32>">
  // CHECK: linalg.matvec {%[[vA]], %[[vB]]} -> {%[[vC]]}
  // clang-format on
  cleanupAndPrintFunction(f);
}
@@ -81,21 +83,84 @@ TEST_FUNC(matmul_as_dot) {
  mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_dot");
  lowerToFinerGrainedTensorContraction(f);
  lowerToFinerGrainedTensorContraction(f);
  composeSliceOps(f);
  // clang-format off
  // CHECK-LABEL: func @matmul_as_dot(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
  // CHECK: %[[M:.*]] = dim %arg0, 0 : memref<?x?xf32>
  // CHECK: %[[N:.*]] = dim %arg2, 1 : memref<?x?xf32>
  // CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[N]]) {
  // CHECK-NEXT: %[[vB:.*]] = linalg.slice {{.*}}[*, %i0] { dim : 1 } : !linalg<"view<f32>">
  // CHECK-NEXT: %[[sC:.*]] = linalg.slice {{.*}}[*, %i0] { dim : 1 } : !linalg<"view<f32>">
  // CHECK: %[[vB:.*]] = linalg.view %arg1[%{{.*}}, %{{.*}}] : !linalg<"view<f32>">
  // CHECK-NEXT: affine.for %i1 = 0 to (d0) -> (d0)(%[[M]]) {
  // CHECK-NEXT: %[[vA:.*]] = linalg.slice {{.*}}[%i1, *] { dim : 0 } : !linalg<"view<f32>">
  // CHECK-NEXT: %[[vC:.*]] = linalg.slice %[[sC]][%i1] { dim : 0 } : !linalg<"view<0xf32>">
  // CHECK: %[[vA:.*]] = linalg.view %arg0[%{{.*}}, %{{.*}}] : !linalg<"view<f32>">
  // CHECK-NEXT: %[[vC:.*]] = linalg.view %arg2[%{{.*}}, %{{.*}}] : !linalg<"view<0xf32>">
  // CHECK-NEXT: linalg.dot {%[[vA]], %[[vB]]} -> {%[[vC]]}
  // clang-format on
  cleanupAndPrintFunction(f);
}

TEST_FUNC(matmul_as_loops) {
  MLIRContext context;
  Module module(&context);
  mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_loops");
  lowerToLoops(f);
  composeSliceOps(f);
  // clang-format off
  // CHECK-LABEL: func @matmul_as_loops(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
  // CHECK: %[[M:.*]] = dim %arg0, 0 : memref<?x?xf32>
  // CHECK: %[[N:.*]] = dim %arg2, 1 : memref<?x?xf32>
  // CHECK: %[[K:.*]] = dim %arg0, 1 : memref<?x?xf32>
  // CHECK: %[[rM:.*]] = linalg.range %c0:%[[M]]:%c1 : !linalg<"range">
  // CHECK: %[[rN:.*]] = linalg.range %c0:%[[N]]:%c1 : !linalg<"range">
  // CHECK: %[[rK:.*]] = linalg.range %c0:%[[K]]:%c1 : !linalg<"range">
  // CHECK: %[[vA:.*]] = linalg.view %arg0[%[[rM]], %[[rK]]] : !linalg<"view<f32xf32>">
  // CHECK: %[[vB:.*]] = linalg.view %arg1[%[[rK]], %[[rN]]] : !linalg<"view<f32xf32>">
  // CHECK: %[[vC:.*]] = linalg.view %arg2[%[[rM]], %[[rN]]] : !linalg<"view<f32xf32>">
  // CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[M]]) {
  // CHECK: affine.for %i1 = 0 to (d0) -> (d0)(%[[N]]) {
  // CHECK: affine.for %i2 = 0 to (d0) -> (d0)(%[[K]]) {
  // CHECK: %{{.*}} = cmpi "eq", %{{.*}} : index
  // CHECK: %{{.*}} = linalg.load %[[vC]][%i0, %i1] : !linalg<"view<f32xf32>">
  // CHECK: %{{.*}} = select {{.*}} : f32
  // CHECK: %{{.*}} = linalg.load %[[vB]][%i2, %i1] : !linalg<"view<f32xf32>">
  // CHECK: %{{.*}} = linalg.load %[[vA]][%i0, %i2] : !linalg<"view<f32xf32>">
  // CHECK: %{{.*}} = mulf {{.*}} : f32
  // CHECK: %{{.*}} = addf {{.*}} : f32
  // CHECK: linalg.store {{.*}}[%i0, %i1] : !linalg<"view<f32xf32>">
  // clang-format on
  cleanupAndPrintFunction(f);
}

TEST_FUNC(matmul_as_matvec_as_loops) {
  MLIRContext context;
  Module module(&context);
  mlir::Function *f =
      makeFunctionWithAMatmulOp(module, "matmul_as_matvec_as_loops");
  lowerToFinerGrainedTensorContraction(f);
  lowerToLoops(f);
  composeSliceOps(f);
  // clang-format off
  // CHECK-LABEL: func @matmul_as_matvec_as_loops(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
  // CHECK: %[[M:.*]] = dim %arg0, 0 : memref<?x?xf32>
  // CHECK: %[[N:.*]] = dim %arg2, 1 : memref<?x?xf32>
  // CHECK: %[[K:.*]] = dim %arg0, 1 : memref<?x?xf32>
  // CHECK: %[[vA:.*]] = linalg.view %arg0[{{.*}}, {{.*}}] : !linalg<"view<f32xf32>">
  // CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[N]]) {
  // CHECK: %[[vB:.*]] = linalg.view %arg1[{{.*}}, {{.*}}] : !linalg<"view<f32>">
  // CHECK: %[[vC:.*]] = linalg.view %arg2[{{.*}}, {{.*}}] : !linalg<"view<f32>">
  // CHECK: affine.for %i1 = 0 to (d0) -> (d0)(%[[M]]) {
  // CHECK: affine.for %i2 = 0 to (d0) -> (d0)(%[[K]]) {
  // CHECK: %{{.*}} = cmpi "eq", %i2, %{{.*}} : index
  // CHECK: %[[C:.*]] = linalg.load %[[vC]][%i1] : !linalg<"view<f32>">
  // CHECK: %[[C2:.*]] = select %{{.*}}, %{{.*}}, %[[C]] : f32
  // CHECK: %[[B:.*]] = linalg.load %[[vB]][%i2] : !linalg<"view<f32>">
  // CHECK: %[[A:.*]] = linalg.load %[[vA]][%i1, %i2] : !linalg<"view<f32xf32>">
  // CHECK: %{{.*}} = mulf %[[A]], %[[B]] : f32
  // CHECK: %{{.*}} = addf %[[C2]], %{{.*}} : f32
  // CHECK: linalg.store %{{.*}}, %{{.*}}[%i1] : !linalg<"view<f32>">
  // clang-format on
  cleanupAndPrintFunction(f);
}

int main() {
  RUN_TESTS();
  return 0;
@@ -0,0 +1,37 @@
//===- Analysis.h - Linalg dialect Analysis function definitions ----------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#ifndef LINALG3_ANALYSIS_H_
#define LINALG3_ANALYSIS_H_

#include "linalg2/Analysis.h"

namespace mlir {
class AffineMap;
} // namespace mlir

namespace linalg {

/// Given a `map` specification and a subset of its results
/// `[beginResult, endResult)`, returns the inverse map that maps result
/// positions to dim positions.
mlir::AffineMap inverseSubMap(mlir::AffineMap map, unsigned beginResult = 0,
                              unsigned endResult = 0);

} // namespace linalg

#endif // LINALG3_ANALYSIS_H_
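For intuition, here is a small worked example of what `inverseSubMap` computes on the matmul map introduced later in this change (illustration only, not part of the diff):

// loopsToOperandRangesMap for matmul:
//   (m, n, k) -> (m, k, k, n, m, n)   // ranges of A(M, K), B(K, N), C(M, N)
// Inverting it over all results (first occurrence of each dim wins) gives
//   (r0, r1, r2, r3, r4, r5) -> (r0, r3, r1)
// i.e. m comes from A's first range, n from B's second range, and k from A's
// second range. This is the operandRangesToLoopsMap that lowerToLoops uses to
// build the enclosing affine.for nest checked in the tests above.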
@@ -0,0 +1,31 @@
//===- Intrinsics.h - Linalg intrinsics definitions -----------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#ifndef LINALG3_INTRINSICS_H_
#define LINALG3_INTRINSICS_H_

#include "linalg2/Intrinsics.h"
#include "linalg3/Ops.h"

namespace linalg {
namespace intrinsics {
using load = mlir::edsc::intrinsics::ValueBuilder<LoadOp>;
using store = mlir::edsc::intrinsics::OperationBuilder<StoreOp>;
} // namespace intrinsics
} // namespace linalg

#endif // LINALG3_INTRINSICS_H_
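These builders are what the scalar implementations below use, through the EDSC TemplatedIndexedValue helper, to read and write views with array-like syntax. A condensed sketch of that usage, lifted from the DotOp implementation later in this change (the surrounding ScopedContext and loop setup are assumed):

using IndexedValue = TemplatedIndexedValue<linalg::intrinsics::load,
                                           linalg::intrinsics::store>;
IndexedValue A(getOperand(0)), B(getOperand(1)), C(getOperand(2));
// A(r_i) and B(r_i) emit linalg.load; assigning to C() emits linalg.store.
C() = scalarC + A(r_i) * B(r_i);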
@@ -0,0 +1,89 @@
//===- LoadStoreOps.h - Linalg dialect Load/Store operation definitions ---===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#ifndef LINALG3_LOADSTOREOP_H_
#define LINALG3_LOADSTOREOP_H_

#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h"

namespace linalg {

class ViewType;

/// A linalg.LoadOp is the counterpart of affine.load but operating on ViewType
/// instead of MemRefType.
class LoadOp : public mlir::Op<LoadOp, mlir::OpTrait::VariadicOperands,
                               mlir::OpTrait::OneResult> {
public:
  using Op::Op;

  //////////////////////////////////////////////////////////////////////////////
  // Hooks to customize the behavior of this op.
  //////////////////////////////////////////////////////////////////////////////
  static llvm::StringRef getOperationName() { return "linalg.load"; }
  static void build(mlir::Builder *b, mlir::OperationState *result,
                    mlir::Value *view,
                    mlir::ArrayRef<mlir::Value *> indices = {});
  mlir::LogicalResult verify();
  static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
  void print(mlir::OpAsmPrinter *p);

  //////////////////////////////////////////////////////////////////////////////
  // Op-specific functionality.
  //////////////////////////////////////////////////////////////////////////////
  unsigned getRank();
  ViewType getViewType();
  mlir::Value *getView() { return getOperand(0); }
  mlir::Operation::operand_range getIndices() {
    return {operand_begin() + 1, operand_end()};
  }
};

/// A linalg.StoreOp is the counterpart of affine.store but operating on
/// ViewType instead of MemRefType.
class StoreOp : public mlir::Op<StoreOp, mlir::OpTrait::VariadicOperands,
                                mlir::OpTrait::ZeroResult> {
public:
  using Op::Op;

  //////////////////////////////////////////////////////////////////////////////
  // Hooks to customize the behavior of this op.
  //////////////////////////////////////////////////////////////////////////////
  static llvm::StringRef getOperationName() { return "linalg.store"; }
  static void build(mlir::Builder *b, mlir::OperationState *result,
                    mlir::Value *valueToStore, mlir::Value *view,
                    mlir::ArrayRef<mlir::Value *> indices = {});
  mlir::LogicalResult verify();
  static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
  void print(mlir::OpAsmPrinter *p);

  //////////////////////////////////////////////////////////////////////////////
  // Op-specific functionality.
  //////////////////////////////////////////////////////////////////////////////
  unsigned getRank();
  ViewType getViewType();
  mlir::Value *getValueToStore() { return getOperand(0); }
  mlir::Value *getView() { return getOperand(1); }
  mlir::Operation::operand_range getIndices() {
    return {operand_begin() + 2, operand_end()};
  }
};

} // namespace linalg

#endif // LINALG3_LOADSTOREOP_H_
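For reference, the printed form of these two ops, as exercised by the lowering tests added in this change, looks like the following (values and types are taken from the CHECK lines above):

// %v = linalg.load %view[%i0, %i1] : !linalg<"view<f32xf32>">
// linalg.store %scalar, %view[%i0, %i1] : !linalg<"view<f32xf32>">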
@@ -19,6 +19,7 @@
#define LINALG3_OPS_H_

#include "linalg2/Ops.h"
#include "linalg3/LoadStoreOps.h"
#include "linalg3/TensorOps.h"

#endif // LINALG3_OPS_H_
@@ -22,9 +22,10 @@
#define LINALG3_TENSOROPS_INL_H_

#include "linalg1/Common.h"
#include "linalg1/Utils.h"
#include "linalg2/TensorOps.h"

namespace linalg {
#include "linalg3/Analysis.h"
#include "linalg3/Ops.h"

template <class ConcreteOp>
mlir::Value *

@@ -38,6 +39,90 @@ linalg::TensorContractionBase<ConcreteOp>::getOutputView(unsigned i) {
  return *(getOutputs().begin() + i);
}

} // namespace linalg
template <class ConcreteOp>
mlir::AffineMap
linalg::TensorContractionBase<ConcreteOp>::loopsToOperandRangesMap() {
  return static_cast<ConcreteOp *>(this)->loopsToOperandRangesMap();
}

#endif // LINALG3_TENSOROPS-INL_H_
template <class ConcreteOp>
void linalg::TensorContractionBase<ConcreteOp>::emitScalarImplementation(
    llvm::ArrayRef<mlir::Value *> parallelIvs,
    llvm::ArrayRef<mlir::Value *> reductionIvs) {
  static_cast<ConcreteOp *>(this)->emitScalarImplementation(parallelIvs,
                                                            reductionIvs);
}

template <class ConcreteOp>
mlir::AffineMap linalg::operandRangesToLoopsMap(
    linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
  return inverseSubMap(tensorContraction.loopsToOperandRangesMap());
}

// Extract the ranges from a given ViewOp or SliceOp.
//
// In the case of a ViewOp, things are simple: just traverse the indexings and
// get all the ranges (i.e. drop the indices).
//
// In the case of a SliceOp, things are trickier because we need to handle a
// potential rank-reduction:
//   1. Examine the indexing to determine if it is rank-reducing.
//   2. If it is rank-reducing, an offset of 1 is added to the dimensions such
//      that `d >= slicingDim`. This is to account for the rank reduction.
// `getViewRootIndexing` is then called on the **parent** view.
static llvm::SmallVector<mlir::Value *, 8>
extractRangesFromViewOrSliceOp(mlir::Value *view) {
  // This expects a viewType which must come from either ViewOp or SliceOp.
  assert(view->getType().isa<linalg::ViewType>() && "expected ViewType");
  if (auto viewOp = view->getDefiningOp()->dyn_cast<linalg::ViewOp>())
    return viewOp.getRanges();

  auto sliceOp = view->getDefiningOp()->cast<linalg::SliceOp>();
  unsigned slicingDim = sliceOp.getSlicingDim();
  auto *indexing = *(sliceOp.getIndexings().begin());
  bool isRankReducing = indexing->getType().isa<mlir::IndexType>();
  unsigned offset = 0;
  llvm::SmallVector<mlir::Value *, 8> res;
  res.reserve(sliceOp.getRank());
  for (unsigned d = 0, e = sliceOp.getRank(); d < e; ++d) {
    if (d == slicingDim && isRankReducing)
      offset = 1;
    auto *parentView = sliceOp.getParentView();
    auto indexingPosPair = linalg::getViewRootIndexing(parentView, d + offset);
    res.push_back(indexingPosPair.first);
  }
  return res;
}

template <class ConcreteOp>
static llvm::SmallVector<mlir::Value *, 8>
getInputRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
  llvm::SmallVector<mlir::Value *, 8> res;
  for (auto *in : tensorContraction.getInputs()) {
    auto subres = extractRangesFromViewOrSliceOp(in);
    res.append(subres.begin(), subres.end());
  }
  return res;
}

template <class ConcreteOp>
static llvm::SmallVector<mlir::Value *, 8>
getOutputRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
  llvm::SmallVector<mlir::Value *, 8> res;
  for (auto *out : tensorContraction.getOutputs()) {
    auto subres = extractRangesFromViewOrSliceOp(out);
    res.append(subres.begin(), subres.end());
  }
  return res;
}

template <class ConcreteOp>
llvm::SmallVector<mlir::Value *, 8> linalg::getRanges(
    linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
  llvm::SmallVector<mlir::Value *, 8> res = getInputRanges(tensorContraction);
  llvm::SmallVector<mlir::Value *, 8> tmp = getOutputRanges(tensorContraction);
  res.append(tmp.begin(), tmp.end());
  return res;
}

#endif // LINALG3_TENSOROPS_INL_H_
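A brief note on the pattern used above: TensorContractionBase dispatches to the derived op with a CRTP-style static_cast rather than virtual functions, so each op only has to provide the two hooks. A hypothetical new contraction op would plug in along these lines (sketch only; BatchMatmulOp is not part of this change, and a real op would also derive from mlir::Op as DotOp/MatvecOp/MatmulOp do):

class BatchMatmulOp : public linalg::TensorContractionBase<BatchMatmulOp> {
public:
  // 1. Declare how the enclosing loops map to the operand ranges.
  mlir::AffineMap loopsToOperandRangesMap();
  // 2. Emit the scalar computation given the enclosing induction variables.
  void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
                                llvm::ArrayRef<mlir::Value *> reductionIvs);
};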
@@ -20,6 +20,32 @@

#include "linalg2/TensorOps.h"

namespace linalg {

///
/// Ideally all these functions would go in an Analysis but, as long as
/// TensorContractionBase is templated, they need to remain close enough.
///

/// Takes a `tensorContraction` and returns an AffineMap that can be used to
/// map ranges to enclosing loops for all the operands' ranges.
template <class ConcreteOp>
mlir::AffineMap operandRangesToLoopsMap(
    linalg::TensorContractionBase<ConcreteOp> &tensorContraction);

/// Takes a `tensorContraction` and returns the ranges of all its operands.
/// When an operand comes from a ViewOp, things are simple: just traverse the
/// indexings and get all the ranges (i.e. drop the rank-reducing indices).
/// In the case of a SliceOp, things are more involved because we need to
/// handle potential rank-reductions.
/// This function abstracts this complexity away and returns all the ranges.
template <class ConcreteOp>
llvm::SmallVector<mlir::Value *, 8>
getRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction);

} // namespace linalg

/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of
/// TensorOps by adding implementations as they are needed in the appropriate
/// step in the tutorial.
@@ -30,10 +30,17 @@ namespace linalg {
/// to only use linalg.view operations.
void composeSliceOps(mlir::Function *f);

/// Traverses `f` and rewrites linalg.matmul (resp. linalg.matvec, linalg.dot)
/// as linalg.matvec (resp. linalg.dot, loop form).
/// Traverses `f` and rewrites linalg.load and linalg.store to affine.load and
/// affine.store operations.
void lowerLinalgLoadStores(mlir::Function *f);

/// Traverses `f` and rewrites linalg.matmul (resp. linalg.matvec)
/// as linalg.matvec (resp. linalg.dot).
void lowerToFinerGrainedTensorContraction(mlir::Function *f);

/// Traverses `f` and rewrites linalg operations in loop form.
void lowerToLoops(mlir::Function *f);

} // namespace linalg

#endif // LINALG3_TRANSFORMS_H_
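Put together, the intended usage of these entry points mirrors the tests added in this change (sketch; makeFunctionWithAMatmulOp and cleanupAndPrintFunction come from the tutorial's test harness):

mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_loops");
lowerToLoops(f);        // matmul -> affine.for nest with linalg.load/store
composeSliceOps(f);     // fold slices produced along the way back into views
cleanupAndPrintFunction(f);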
@@ -0,0 +1,62 @@
//===- Analysis.cpp - Implementation of analysis functions for Linalg -----===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements analysis functions for the linalg dialect, in
// particular the inversion of permutation-like affine maps.
//
//===----------------------------------------------------------------------===//

#include "linalg3/Analysis.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/StandardTypes.h"

using llvm::SmallVector;
using namespace mlir;

// Compute an inverse map (only works with permutations for now).
// Note that the mapping is generally non-full rank, so this returns the first
// seen entry for each dim.
static AffineMap inversePermutationMap(AffineMap map) {
  SmallVector<AffineExpr, 4> exprs(map.getNumDims());
  for (auto en : llvm::enumerate(map.getResults())) {
    auto expr = en.value();
    auto d = expr.dyn_cast<AffineDimExpr>();
    assert(d && "permutation map expected");
    if (exprs[d.getPosition()])
      continue;
    exprs[d.getPosition()] = getAffineDimExpr(en.index(), d.getContext());
  }
  SmallVector<AffineExpr, 4> seenExprs;
  seenExprs.reserve(map.getNumDims());
  for (auto expr : exprs)
    if (expr)
      seenExprs.push_back(expr);
  assert(map.getNumSymbols() == 0 && "expected map without symbols");
  assert(seenExprs.size() == map.getNumInputs() && "map is not invertible");
  return AffineMap::get(map.getNumResults(), 0, seenExprs, {});
}

mlir::AffineMap linalg::inverseSubMap(AffineMap map, unsigned beginResult,
                                      unsigned endResult) {
  if (beginResult == 0 && endResult == 0)
    endResult = map.getNumResults();
  auto subMap = AffineMap::get(
      map.getNumDims(), map.getNumSymbols(),
      map.getResults().slice(beginResult, endResult - beginResult), {});
  return inversePermutationMap(subMap);
}
@@ -0,0 +1,39 @@
//===- DialectRegistration.cpp - Registration of the Linalg dialect -------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file registers the Linalg dialect and should live in a standalone
// library. Linking with this library will create a static global object that
// performs dialect registration.
//
//===----------------------------------------------------------------------===//

#include "linalg1/Dialect.h"
#include "linalg1/Types.h"
#include "linalg3/Ops.h"

using namespace linalg;

LinalgDialect::LinalgDialect(mlir::MLIRContext *context)
    : Dialect("linalg", context) {
  addTypes<RangeType, ViewType>();
  addOperations<DotOp, LoadOp, MatvecOp, MatmulOp, RangeOp, SliceOp, StoreOp,
                ViewOp>();
}

// Dialect registration triggers the creation of a `LinalgDialect` object which
// adds the proper types and operations to the dialect.
static mlir::DialectRegistration<LinalgDialect> LinalgOps;
@@ -0,0 +1,136 @@
|
|||
//===- LoadStoreOps.cpp - Implementation of linalg Load/Store operations --===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements linalg.load and linalg.store operations which allow
|
||||
// accessing memory through ViewType values.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "linalg3/LoadStoreOps.h"
|
||||
#include "linalg3/Ops.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
using llvm::ArrayRef;
|
||||
using namespace mlir;
|
||||
using namespace linalg;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// LoadOp.
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
void linalg::LoadOp::build(Builder *b, OperationState *result, Value *view,
|
||||
ArrayRef<Value *> indices) {
|
||||
auto viewType = view->getType().cast<ViewType>();
|
||||
result->addOperands(view);
|
||||
result->addOperands(indices);
|
||||
result->addTypes(viewType.getElementType());
|
||||
}
|
||||
|
||||
void linalg::LoadOp::print(OpAsmPrinter *p) {
|
||||
*p << getOperationName() << " " << *getView() << '[';
|
||||
p->printOperands(getIndices());
|
||||
*p << ']';
|
||||
p->printOptionalAttrDict(getAttrs());
|
||||
*p << " : " << getViewType();
|
||||
}
|
||||
|
||||
bool linalg::LoadOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
|
||||
return false;
|
||||
}
|
||||
|
||||
LogicalResult linalg::LoadOp::verify() {
|
||||
if (getNumOperands() == 0)
|
||||
return emitOpError("expected a view to load from");
|
||||
|
||||
auto viewType = getView()->getType().dyn_cast<ViewType>();
|
||||
if (!viewType)
|
||||
return emitOpError("first operand must be a view");
|
||||
|
||||
if (getType() != viewType.getElementType())
|
||||
return emitOpError("result type must match element type of the view");
|
||||
|
||||
if (getRank() != getNumOperands() - 1)
|
||||
return emitOpError("incorrect number of indices for load");
|
||||
|
||||
for (auto *idx : getIndices())
|
||||
if (!idx->getType().isIndex())
|
||||
return emitOpError("index to load must have 'index' type");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
ViewType linalg::LoadOp::getViewType() {
|
||||
return getView()->getType().cast<ViewType>();
|
||||
}
|
||||
|
||||
unsigned linalg::LoadOp::getRank() { return getViewType().getRank(); }
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// StoreOp.
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
void linalg::StoreOp::build(Builder *b, OperationState *result,
|
||||
Value *valueToStore, Value *view,
|
||||
ArrayRef<Value *> indices) {
|
||||
result->addOperands(valueToStore);
|
||||
result->addOperands(view);
|
||||
result->addOperands(indices);
|
||||
}
|
||||
|
||||
void linalg::StoreOp::print(OpAsmPrinter *p) {
|
||||
*p << getOperationName() << " " << *getValueToStore();
|
||||
*p << ", " << *getView() << '[';
|
||||
p->printOperands(getIndices());
|
||||
*p << ']';
|
||||
p->printOptionalAttrDict(getAttrs());
|
||||
*p << " : " << getViewType();
|
||||
}
|
||||
|
||||
bool linalg::StoreOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
assert(false && "NYI");
|
||||
return false;
|
||||
}
|
||||
|
||||
LogicalResult linalg::StoreOp::verify() {
|
||||
if (getNumOperands() < 2)
|
||||
return emitOpError("expected a value to store and a view");
|
||||
|
||||
// Second operand must be of ViewType.
|
||||
auto viewType = getView()->getType().dyn_cast<ViewType>();
|
||||
if (!viewType)
|
||||
return emitOpError("second operand must be a view");
|
||||
|
||||
// First operand must have the same type as the view's element type.
|
||||
if (getValueToStore()->getType() != viewType.getElementType())
|
||||
return emitOpError("first operand must have same element type as the view");
|
||||
|
||||
if (getNumOperands() != 2 + viewType.getRank())
|
||||
return emitOpError("store index operand count not equal to view rank");
|
||||
|
||||
for (auto *idx : getIndices())
|
||||
if (!idx->getType().isIndex())
|
||||
return emitOpError("index to store must have 'index' type");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
unsigned linalg::StoreOp::getRank() { return getViewType().getRank(); }
|
||||
|
||||
ViewType linalg::StoreOp::getViewType() {
|
||||
return getView()->getType().cast<ViewType>();
|
||||
}
|
|
@@ -22,7 +22,7 @@
|
|||
|
||||
#include "linalg1/Analysis.h"
|
||||
#include "linalg1/Common.h"
|
||||
#include "linalg2/Intrinsics.h"
|
||||
#include "linalg3/Intrinsics.h"
|
||||
#include "linalg3/Ops.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
@@ -32,23 +32,121 @@
|
|||
|
||||
using namespace mlir;
|
||||
using namespace mlir::edsc;
|
||||
using namespace mlir::edsc::intrinsics;
|
||||
using namespace linalg;
|
||||
using namespace linalg::intrinsics;
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Implementation of DotOp.
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
AffineMap linalg::DotOp::loopsToOperandRangesMap() {
|
||||
// A(K), B(K), C()
|
||||
assert(getRanges(*this).size() == 2);
|
||||
auto *context = ScopedContext::getContext();
|
||||
auto d0 = getAffineDimExpr(0, context); // K
|
||||
// A(K), B(K), C()
|
||||
// (d0) -> (d0, d0)(%k)
|
||||
return AffineMap::get(1, 0, {d0, d0}, {});
|
||||
}
|
||||
|
||||
void linalg::DotOp::emitScalarImplementation(
|
||||
llvm::ArrayRef<Value *> parallelIvs, llvm::ArrayRef<Value *> reductionIvs) {
|
||||
using IndexedValue = TemplatedIndexedValue<linalg::intrinsics::load,
|
||||
linalg::intrinsics::store>;
|
||||
assert(reductionIvs.size() == 1);
|
||||
auto innermostLoop = getForInductionVarOwner(reductionIvs.back());
|
||||
auto *body = innermostLoop.getBody();
|
||||
using edsc::op::operator+;
|
||||
using edsc::op::operator*;
|
||||
using edsc::op::operator==;
|
||||
using edsc::intrinsics::select;
|
||||
ScopedContext scope( // account for affine.terminator in loop.
|
||||
FuncBuilder(body, std::prev(body->end(), 1)), innermostLoop.getLoc());
|
||||
auto f32 = ScopedContext::getBuilder()->getF32Type();
|
||||
IndexHandle zero(constant_index(0));
|
||||
ValueHandle zerof = constant_float(llvm::APFloat(0.0f), f32);
|
||||
IndexHandle r_i(reductionIvs[0]);
|
||||
IndexedValue A(getOperand(0)), B(getOperand(1)), C(getOperand(2));
|
||||
ValueHandle cond = (r_i == zero);
|
||||
ValueHandle scalarC = select(cond, zerof, *C());
|
||||
C() = scalarC + A(r_i) * B(r_i);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Implementation of MatvecOp.
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
AffineMap linalg::MatvecOp::loopsToOperandRangesMap() {
|
||||
// A(M, K), B(K), C(M)
|
||||
assert(getRanges(*this).size() == 4);
|
||||
auto *context = ScopedContext::getContext();
|
||||
auto d0 = getAffineDimExpr(0, context); // M
|
||||
auto d1 = getAffineDimExpr(1, context); // K
|
||||
// A(M, K), B(K), C(M)
|
||||
// (d0, d1) -> (d0, d1, d1, d0)(%m, %k)
|
||||
return AffineMap::get(2, 0, {d0, d1, d1, d0}, {});
|
||||
}
|
||||
|
||||
// The body expression for matvec is: C(i) = scalarC + A(i, r_j) * B(r_j)
|
||||
// The body expression for dot is: C() = A(r_i) * B(r_i);
|
||||
// So we must drop the `i` loop from the matvec.
|
||||
void linalg::MatvecOp::writeAsFinerGrainTensorContraction() {
|
||||
auto *op = getOperation();
|
||||
ScopedContext scope(FuncBuilder(op), op->getLoc());
|
||||
IndexHandle i;
|
||||
auto *vA(getInputView(0)), *vB(getInputView(1)), *vC(getOutputView(0));
|
||||
auto indexingPosPair = getViewRootIndexing(vA, 0);
|
||||
assert(indexingPosPair.first->getDefiningOp() &&
|
||||
indexingPosPair.first->getDefiningOp()->isa<RangeOp>());
|
||||
linalg::common::LoopNestRangeBuilder(&i, ValueHandle(indexingPosPair.first))({
|
||||
dot(slice(vA, i, 0), vB, slice(vC, i, 0)),
|
||||
// clang-format off
|
||||
ScopedContext scope(FuncBuilder(op), op->getLoc());
|
||||
IndexHandle i;
|
||||
using linalg::common::LoopNestRangeBuilder;
|
||||
LoopNestRangeBuilder(&i, ValueHandle(indexingPosPair.first))({
|
||||
[&i, &vA, &vB, &vC]() {
|
||||
ValueHandle sliceA = slice(vA, i, 0);
|
||||
ValueHandle sliceC = slice(vC, i, 0);
|
||||
dot(sliceA, vB, sliceC);
|
||||
/// NestedBuilders expect handles, we thus return an IndexHandle.
|
||||
return IndexHandle();
|
||||
}()
|
||||
});
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
void linalg::MatvecOp::emitScalarImplementation(
|
||||
llvm::ArrayRef<Value *> parallelIvs, llvm::ArrayRef<Value *> reductionIvs) {
|
||||
using IndexedValue = TemplatedIndexedValue<linalg::intrinsics::load,
|
||||
linalg::intrinsics::store>;
|
||||
assert(reductionIvs.size() == 1);
|
||||
auto innermostLoop = getForInductionVarOwner(reductionIvs.back());
|
||||
auto *body = innermostLoop.getBody();
|
||||
using edsc::op::operator+;
|
||||
using edsc::op::operator*;
|
||||
using edsc::op::operator==;
|
||||
using edsc::intrinsics::select;
|
||||
ScopedContext scope( // account for affine.terminator in loop.
|
||||
FuncBuilder(body, std::prev(body->end(), 1)), innermostLoop.getLoc());
|
||||
auto f32 = ScopedContext::getBuilder()->getF32Type();
|
||||
IndexHandle i(parallelIvs[0]), r_j(reductionIvs[0]);
|
||||
IndexedValue A(getOperand(0)), B(getOperand(1)), C(getOperand(2));
|
||||
IndexHandle zero(constant_index(0));
|
||||
ValueHandle zerof = constant_float(llvm::APFloat(0.0f), f32);
|
||||
ValueHandle cond = (r_j == zero);
|
||||
ValueHandle scalarC = select(cond, zerof, *C(i));
|
||||
C(i) = scalarC + A(i, r_j) * B(r_j);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Implementation of MatmulOp.
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
AffineMap linalg::MatmulOp::loopsToOperandRangesMap() {
|
||||
// A(M, K), B(K, N), C(M, N)
|
||||
assert(getRanges(*this).size() == 6);
|
||||
auto *context = ScopedContext::getContext();
|
||||
auto d0 = getAffineDimExpr(0, context); // M
|
||||
auto d1 = getAffineDimExpr(1, context); // N
|
||||
auto d2 = getAffineDimExpr(2, context); // K
|
||||
// A(M, K), B(K, N), C(M, N):
|
||||
// (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1)(%m, %n, %k)
|
||||
return AffineMap::get(3, 0, {d0, d2, d2, d1, d0, d1}, {});
|
||||
}
|
||||
|
||||
// The body expression for matmul is: C(i, j) = scalarC + A(i, r_k) * B(r_k, j)
|
||||
|
@@ -58,13 +156,45 @@ void linalg::MatvecOp::writeAsFinerGrainTensorContraction() {
|
|||
// declaratively.
|
||||
void linalg::MatmulOp::writeAsFinerGrainTensorContraction() {
|
||||
auto *op = getOperation();
|
||||
ScopedContext scope(FuncBuilder(op), op->getLoc());
|
||||
IndexHandle j;
|
||||
auto *vA(getInputView(0)), *vB(getInputView(1)), *vC(getOutputView(0));
|
||||
auto indexingPosPair = getViewRootIndexing(vB, 1);
|
||||
assert(indexingPosPair.first->getDefiningOp() &&
|
||||
indexingPosPair.first->getDefiningOp()->isa<RangeOp>());
|
||||
linalg::common::LoopNestRangeBuilder(&j, ValueHandle(indexingPosPair.first))({
|
||||
matvec(vA, slice(vB, j, 1), slice(vC, j, 1)),
|
||||
using linalg::common::LoopNestRangeBuilder;
|
||||
// clang-format off
|
||||
ScopedContext scope(FuncBuilder(op), op->getLoc());
|
||||
IndexHandle j;
|
||||
LoopNestRangeBuilder(&j, ValueHandle(indexingPosPair.first))({
|
||||
[&j, &vA, &vB, &vC]() {
|
||||
ValueHandle sliceB = slice(vB, j, 1);
|
||||
ValueHandle sliceC = slice(vC, j, 1);
|
||||
matvec(vA, sliceB, sliceC);
|
||||
/// NestedBuilders expect handles, we thus return an IndexHandle.
|
||||
return IndexHandle();
|
||||
}()
|
||||
});
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
void linalg::MatmulOp::emitScalarImplementation(
|
||||
llvm::ArrayRef<Value *> parallelIvs, llvm::ArrayRef<Value *> reductionIvs) {
|
||||
using IndexedValue = TemplatedIndexedValue<linalg::intrinsics::load,
|
||||
linalg::intrinsics::store>;
|
||||
assert(reductionIvs.size() == 1);
|
||||
auto innermostLoop = getForInductionVarOwner(reductionIvs.back());
|
||||
auto *body = innermostLoop.getBody();
|
||||
using edsc::op::operator+;
|
||||
using edsc::op::operator*;
|
||||
using edsc::op::operator==;
|
||||
using edsc::intrinsics::select;
|
||||
ScopedContext scope( // account for affine.terminator in loop.
|
||||
FuncBuilder(body, std::prev(body->end(), 1)), innermostLoop.getLoc());
|
||||
auto f32 = ScopedContext::getBuilder()->getF32Type();
|
||||
IndexHandle i(parallelIvs[0]), j(parallelIvs[1]), r_k(reductionIvs[0]);
|
||||
IndexedValue A(getOperand(0)), B(getOperand(1)), C(getOperand(2));
|
||||
IndexHandle zero(constant_index(0));
|
||||
ValueHandle zerof = constant_float(llvm::APFloat(0.0f), f32);
|
||||
ValueHandle cond = r_k == zero;
|
||||
ValueHandle scalarC = select(cond, zerof, *C(i, j));
|
||||
C(i, j) = scalarC + A(i, r_k) * B(r_k, j);
|
||||
}
|
||||
|
|
|
@@ -53,3 +53,171 @@ void linalg::lowerToFinerGrainedTensorContraction(mlir::Function *f) {
|
|||
op->erase();
|
||||
});
|
||||
}
|
||||
|
||||
// Folding eagerly is necessary to abide by affine.for static step requirement.
|
||||
// Returns nullptr if folding is not trivially feasible.
|
||||
static Value *tryFold(AffineMap map, SmallVector<Value *, 4> operands) {
|
||||
assert(map.getNumResults() == 1 && "single result map expected");
|
||||
auto expr = map.getResult(0);
|
||||
if (auto dim = expr.dyn_cast<AffineDimExpr>())
|
||||
return operands[dim.getPosition()];
|
||||
if (auto sym = expr.dyn_cast<AffineSymbolExpr>())
|
||||
return operands[map.getNumDims() + sym.getPosition()];
|
||||
if (auto cst = expr.dyn_cast<AffineConstantExpr>())
|
||||
return constant_index(cst.getValue());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static Value *makeFoldedComposedAffineApply(AffineMap map,
|
||||
ArrayRef<Value *> operandsRef) {
|
||||
SmallVector<Value *, 4> operands(operandsRef.begin(), operandsRef.end());
|
||||
fullyComposeAffineMapAndOperands(&map, &operands);
|
||||
if (auto *v = tryFold(map, operands)) {
|
||||
return v;
|
||||
}
|
||||
auto *b = ScopedContext::getBuilder();
|
||||
auto loc = ScopedContext::getLocation();
|
||||
return b->create<AffineApplyOp>(loc, map, operands).getResult();
|
||||
}
|
||||
|
||||
struct RangeParts {
|
||||
explicit RangeParts(unsigned reserved);
|
||||
RangeParts(ArrayRef<Value *> ranges);
|
||||
|
||||
SmallVector<Value *, 4> makeRanges();
|
||||
|
||||
SmallVector<Value *, 4> mins;
|
||||
SmallVector<Value *, 4> maxes;
|
||||
SmallVector<Value *, 4> steps;
|
||||
};
|
||||
|
||||
RangeParts::RangeParts(unsigned reserved) {
|
||||
mins.reserve(reserved);
|
||||
maxes.reserve(reserved);
|
||||
steps.reserve(reserved);
|
||||
}
|
||||
|
||||
static SmallVector<Value *, 4>
|
||||
extractFromRanges(ArrayRef<Value *> ranges,
|
||||
std::function<Value *(RangeOp)> extract) {
|
||||
SmallVector<Value *, 4> res;
|
||||
res.reserve(ranges.size());
|
||||
for (auto *v : ranges) {
|
||||
auto r = v->getDefiningOp()->cast<RangeOp>();
|
||||
res.push_back(extract(r));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
RangeParts::RangeParts(ArrayRef<Value *> ranges)
|
||||
: mins(extractFromRanges(ranges, [](RangeOp r) { return r.getMin(); })),
|
||||
maxes(extractFromRanges(ranges, [](RangeOp r) { return r.getMax(); })),
|
||||
steps(extractFromRanges(ranges, [](RangeOp r) { return r.getStep(); })) {}
|
||||
|
||||
SmallVector<Value *, 4> RangeParts::makeRanges() {
|
||||
SmallVector<Value *, 4> res;
|
||||
res.reserve(mins.size());
|
||||
for (auto z : llvm::zip(mins, maxes, steps)) {
|
||||
res.push_back(range(std::get<0>(z), std::get<1>(z), std::get<2>(z)));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
static RangeParts makeGenericRangeParts(AffineMap map,
|
||||
ArrayRef<Value *> ranges) {
|
||||
assert(map.getNumInputs() == ranges.size());
|
||||
unsigned numDims = map.getNumDims();
|
||||
assert(map.getNumSymbols() == 0);
|
||||
assert(map.getRangeSizes().empty());
|
||||
|
||||
RangeParts res(map.getNumResults());
|
||||
RangeParts rangeParts(ranges);
|
||||
for (auto expr : map.getResults()) {
|
||||
AffineMap map = AffineMap::get(numDims, 0, expr, {});
|
||||
res.mins.push_back(makeFoldedComposedAffineApply(map, rangeParts.mins));
|
||||
res.maxes.push_back(makeFoldedComposedAffineApply(map, rangeParts.maxes));
|
||||
res.steps.push_back(makeFoldedComposedAffineApply(map, rangeParts.steps));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
SmallVector<Value *, 4> makeGenericRanges(AffineMap map,
|
||||
ArrayRef<Value *> ranges) {
|
||||
return makeGenericRangeParts(map, ranges).makeRanges();
|
||||
}
|
||||
|
||||
static SmallVector<Value *, 4> makeGenericLoopRanges(
|
||||
AffineMap operandRangesToLoopsMap, ArrayRef<Value *> ranges,
|
||||
llvm::Optional<ArrayRef<Value *>> tileSizes = llvm::None) {
|
||||
RangeParts res = makeGenericRangeParts(operandRangesToLoopsMap, ranges);
|
||||
if (!tileSizes.hasValue())
|
||||
return res.makeRanges();
|
||||
SmallVector<Value *, 4> tiledSteps;
|
||||
for (auto z : llvm::zip(res.steps, *tileSizes)) {
|
||||
auto *step = std::get<0>(z);
|
||||
auto tileSize = std::get<1>(z);
|
||||
auto stepValue = step->getDefiningOp()->cast<ConstantIndexOp>().getValue();
|
||||
auto tileSizeValue =
|
||||
tileSize->getDefiningOp()->cast<ConstantIndexOp>().getValue();
|
||||
assert(stepValue > 0);
|
||||
tiledSteps.push_back(constant_index(stepValue * tileSizeValue));
|
||||
}
|
||||
res.steps = tiledSteps;
|
||||
return res.makeRanges();
|
||||
}
|
||||
|
||||
template <class ContractionOp>
|
||||
static SmallVector<mlir::AffineForOp, 4>
|
||||
writeAsLoops(ContractionOp contraction) {
|
||||
ScopedContext scope(mlir::FuncBuilder(contraction.getOperation()),
|
||||
contraction.getLoc());
|
||||
auto loopRanges = makeGenericLoopRanges(operandRangesToLoopsMap(contraction),
|
||||
getRanges(contraction));
|
||||
|
||||
SmallVector<IndexHandle, 4> parallelIvs(contraction.getNumParallelDims());
|
||||
SmallVector<IndexHandle, 4> reductionIvs(contraction.getNumReductionDims());
|
||||
auto pivs = IndexHandle::makeIndexHandlePointers(parallelIvs);
|
||||
auto rivs = IndexHandle::makeIndexHandlePointers(reductionIvs);
|
||||
assert(loopRanges.size() == pivs.size() + rivs.size());
|
||||
|
||||
// clang-format off
|
||||
using linalg::common::LoopNestRangeBuilder;
|
||||
ArrayRef<Value *> ranges(loopRanges);
|
||||
LoopNestRangeBuilder(pivs, ranges.take_front(pivs.size()))({
|
||||
LoopNestRangeBuilder(rivs, ranges.take_back(rivs.size()))({
|
||||
[&contraction, ¶llelIvs, &reductionIvs]() {
|
||||
SmallVector<mlir::Value *, 4> parallel(
|
||||
parallelIvs.begin(), parallelIvs.end());
|
||||
SmallVector<mlir::Value *, 4> reduction(
|
||||
reductionIvs.begin(), reductionIvs.end());
|
||||
contraction.emitScalarImplementation(parallel, reduction);
|
||||
/// NestedBuilders expect handles, we thus return an IndexHandle.
|
||||
return IndexHandle();
|
||||
}()
|
||||
})
|
||||
});
|
||||
// clang-format on
|
||||
|
||||
SmallVector<mlir::AffineForOp, 4> res;
|
||||
res.reserve(pivs.size() + rivs.size());
|
||||
for (auto iv : parallelIvs)
|
||||
res.push_back(getForInductionVarOwner(iv.getValue()));
|
||||
for (auto iv : reductionIvs)
|
||||
res.push_back(getForInductionVarOwner(iv.getValue()));
|
||||
return res;
|
||||
}
|
||||
|
||||
void linalg::lowerToLoops(mlir::Function *f) {
|
||||
f->walkPostOrder([](Operation *op) {
|
||||
if (auto matmulOp = op->dyn_cast<linalg::MatmulOp>()) {
|
||||
writeAsLoops(matmulOp);
|
||||
} else if (auto matvecOp = op->dyn_cast<linalg::MatvecOp>()) {
|
||||
writeAsLoops(matvecOp);
|
||||
} else if (auto dotOp = op->dyn_cast<linalg::DotOp>()) {
|
||||
writeAsLoops(dotOp);
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
op->erase();
|
||||
});
|
||||
}
|
||||