[mlir] Add pass to convert elementwise ops to linalg.

This patch converts elementwise ops on tensors to linalg.generic ops
with the same elementwise op in the payload (except rewritten to
operate on scalars, obviously). This is a great form for later fusion to
clean up.

E.g.

```
// Compute: %arg0 + %arg1 - %arg2
func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
  %0 = addf %arg0, %arg1 : tensor<?xf32>
  %1 = subf %0, %arg2 : tensor<?xf32>
  return %1 : tensor<?xf32>
}
```

Running this through
`mlir-opt -convert-std-to-linalg -linalg-fusion-for-tensor-ops` we get:

```
func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
  %0 = linalg.generic {indexing_maps = [#map0, #map0, #map0, #map0], iterator_types = ["parallel"]} ins(%arg0, %arg1, %arg2 : tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) {
  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
    %1 = addf %arg3, %arg4 : f32
    %2 = subf %1, %arg5 : f32
    linalg.yield %2 : f32
  } -> tensor<?xf32>
  return %0 : tensor<?xf32>
}
```

So the elementwise ops on tensors have nicely collapsed into a single
linalg.generic, which is the form we want for further transformations.

Differential Revision: https://reviews.llvm.org/D90354
This commit is contained in:
Sean Silva 2020-10-28 13:25:48 -07:00
parent b4fa28b408
commit 53a0d45db6
6 changed files with 196 additions and 0 deletions

View File

@ -16,6 +16,8 @@
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
namespace mlir { namespace mlir {
std::unique_ptr<OperationPass<FuncOp>> createConvertElementwiseToLinalgPass();
std::unique_ptr<OperationPass<FuncOp>> createLinalgFoldUnitExtentDimsPass(); std::unique_ptr<OperationPass<FuncOp>> createLinalgFoldUnitExtentDimsPass();
std::unique_ptr<Pass> createLinalgFusionOfTensorOpsPass(); std::unique_ptr<Pass> createLinalgFusionOfTensorOpsPass();
@ -48,6 +50,11 @@ std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToAffineLoopsPass();
/// buffers instead. /// buffers instead.
std::unique_ptr<OperationPass<ModuleOp>> createLinalgBufferizePass(); std::unique_ptr<OperationPass<ModuleOp>> createLinalgBufferizePass();
/// Populate patterns that convert `ElementwiseMappable` ops to linalg
/// parallel loops.
void populateElementwiseToLinalgConversionPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx);
/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its /// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
/// producer (consumer) generic operation by expanding the dimensionality of the /// producer (consumer) generic operation by expanding the dimensionality of the
/// loop in the generic op. /// loop in the generic op.

View File

@ -11,6 +11,17 @@
include "mlir/Pass/PassBase.td" include "mlir/Pass/PassBase.td"
def ConvertElementwiseToLinalg : FunctionPass<"convert-elementwise-to-linalg"> {
let summary = "Convert ElementwiseMappable ops to linalg";
let description = [{
Convert ops with the `ElementwiseMappable` trait to linalg parallel loops.
This pass only converts ops that operate on ranked tensors.
}];
let constructor = "mlir::createConvertElementwiseToLinalgPass()";
let dependentDialects = ["linalg::LinalgDialect"];
}
def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> { def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> {
let summary = "Remove unit-extent dimension in Linalg ops on tensors"; let summary = "Remove unit-extent dimension in Linalg ops on tensors";
let constructor = "mlir::createLinalgFoldUnitExtentDimsPass()"; let constructor = "mlir::createLinalgFoldUnitExtentDimsPass()";

View File

@ -0,0 +1,19 @@
// RUN: mlir-opt %s -convert-elementwise-to-linalg -std-bufferize -linalg-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
func @main() {
%a = constant dense<[1.0, 2.0, 3.0]> : tensor<3xf32>
%b = constant dense<[10.0, 20.0, 30.0]> : tensor<3xf32>
%addf = addf %a, %b : tensor<3xf32>
%addf_unranked = tensor_cast %addf : tensor<3xf32> to tensor<*xf32>
call @print_memref_f32(%addf_unranked) : (tensor<*xf32>) -> ()
// CHECK: Unranked Memref base@ = {{.*}} rank = 1 offset = 0 sizes = [3] strides = [1] data =
// CHECK-NEXT: [11, 22, 33]
return
}
func @print_memref_f32(%ptr : tensor<*xf32>)

View File

@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
Bufferize.cpp Bufferize.cpp
CodegenStrategy.cpp CodegenStrategy.cpp
DropUnitDims.cpp DropUnitDims.cpp
ElementwiseToLinalg.cpp
Fusion.cpp Fusion.cpp
FusionOnTensors.cpp FusionOnTensors.cpp
Hoisting.cpp Hoisting.cpp

View File

@ -0,0 +1,98 @@
//===- ElementwiseToLinalg.cpp - conversion of elementwise to linalg ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/Passes.h"
#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
if (!op->hasTrait<OpTrait::ElementwiseMappable>())
return false;
// TODO: The conversion pattern can be made to work for `any_of` here, but
// it's more complex as it requires tracking which operands are scalars.
return llvm::all_of(op->getOperandTypes(),
[](Type type) { return type.isa<RankedTensorType>(); });
}
namespace {
struct ConvertStdElementwiseOpOnRankedTensors : public RewritePattern {
ConvertStdElementwiseOpOnRankedTensors()
: RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const final {
if (!isElementwiseMappableOpOnRankedTensors(op))
return rewriter.notifyMatchFailure(
op, "requires elementwise op on ranked tensors");
auto rank = op->getResult(0).getType().cast<RankedTensorType>().getRank();
SmallVector<AffineMap, 3> indexingMaps(
op->getNumResults() + op->getNumOperands(),
rewriter.getMultiDimIdentityMap(rank));
SmallVector<StringRef, 6> iteratorTypes(rank,
getParallelIteratorTypeName());
rewriter.replaceOpWithNewOp<linalg::GenericOp>(
op, /*resultTensorTypes=*/op->getResultTypes(),
/*inputs=*/op->getOperands(),
/*outputBuffers=*/ValueRange(),
/*initTensors=*/ValueRange(),
/*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/iteratorTypes,
/*bodyBuilder=*/
[&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
OperationState state(loc, op->getName());
state.addAttributes(op->getAttrs());
state.addOperands(regionArgs);
auto resultTypes = llvm::to_vector<6>(
llvm::map_range(op->getResultTypes(), [](Type type) {
return type.cast<TensorType>().getElementType();
}));
state.addTypes(resultTypes);
auto *scalarOp = builder.createOperation(state);
builder.create<linalg::YieldOp>(loc, scalarOp->getResults());
});
return success();
}
};
} // namespace
void mlir::populateElementwiseToLinalgConversionPatterns(
OwningRewritePatternList &patterns, MLIRContext *) {
patterns.insert<ConvertStdElementwiseOpOnRankedTensors>();
}
namespace {
class ConvertElementwiseToLinalgPass
: public ConvertElementwiseToLinalgBase<ConvertElementwiseToLinalgPass> {
void runOnFunction() final {
auto func = getOperation();
auto *context = &getContext();
ConversionTarget target(*context);
OwningRewritePatternList patterns;
populateElementwiseToLinalgConversionPatterns(patterns, context);
target.markUnknownOpDynamicallyLegal([](Operation *op) {
return !isElementwiseMappableOpOnRankedTensors(op);
});
if (failed(applyPartialConversion(func, target, std::move(patterns))))
signalPassFailure();
}
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertElementwiseToLinalgPass() {
return std::make_unique<ConvertElementwiseToLinalgPass>();
}

View File

@ -0,0 +1,60 @@
// RUN: mlir-opt -convert-elementwise-to-linalg -split-input-file %s | FileCheck %s
// In-depth checking of the linalg.generic op for a very trivial case.
// CHECK: #map = affine_map<() -> ()>
// CHECK-LABEL: func @addf_rank0
func @addf_rank0(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
// CHECK: %{{.*}} = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} ins(%{{.*}}, %{{.*}} : tensor<f32>, tensor<f32>) {
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
// CHECK: %[[YIELD:.*]] = addf %[[LHS]], %[[RHS]] : f32
// CHECK: linalg.yield %[[YIELD]] : f32
// CHECK: } -> tensor<f32>
%0 = addf %arg0, %arg1 : tensor<f32>
return %0 : tensor<f32>
}
// -----
// Check indexing maps and iterator types for the rank > 0 case.
// CHECK: #map = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func @addf_rank1
func @addf_rank1(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: linalg.generic{{.*}}indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]
%0 = addf %arg0, %arg1 : tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// Check a unary op.
// CHECK-LABEL: func @exp
func @exp(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK: linalg.generic
// CHECK: ^bb0(%[[SCALAR:.*]]: f32):
// CHECK: %[[YIELD:.*]] = exp %[[SCALAR]] : f32
// CHECK: linalg.yield %[[YIELD]] : f32
%0 = exp %arg0 : tensor<f32>
return %0 : tensor<f32>
}
// -----
// Check a case with varying operand types.
// CHECK-LABEL: func @select
func @select(%arg0: tensor<i1>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<i32> {
// CHECK: linalg.generic
// CHECK: ^bb0(%[[PRED:.*]]: i1, %[[TRUE_VAL:.*]]: i32, %[[FALSE_VAL:.*]]: i32):
// CHECK: select %[[PRED]], %[[TRUE_VAL]], %[[FALSE_VAL]] : i32
%0 = select %arg0, %arg1, %arg2 : tensor<i1>, tensor<i32>
return %0 : tensor<i32>
}
// -----
// Spot-check an op that requires copying attributes properly to the created scalar op.
// CHECK-LABEL: func @cmpf(
func @cmpf(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<i1> {
// CHECK: cmpf "olt", %{{.*}}, %{{.*}} : f32
%0 = cmpf "olt", %arg0, %arg1 : tensor<f32>
return %0 : tensor<i1>
}