[mlir] Add pass to convert elementwise ops to linalg.
This patch converts elementwise ops on tensors to linalg.generic ops with the same elementwise op in the payload (except rewritten to operate on scalars, obviously). This is a great form for later fusion to clean up.

E.g.

```
// Compute: %arg0 + %arg1 - %arg2
func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
  %0 = addf %arg0, %arg1 : tensor<?xf32>
  %1 = subf %0, %arg2 : tensor<?xf32>
  return %1 : tensor<?xf32>
}
```

Running this through `mlir-opt -convert-elementwise-to-linalg -linalg-fusion-for-tensor-ops` we get:

```
func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
  %0 = linalg.generic {indexing_maps = [#map0, #map0, #map0, #map0], iterator_types = ["parallel"]}
      ins(%arg0, %arg1, %arg2 : tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) {
  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
    %1 = addf %arg3, %arg4 : f32
    %2 = subf %1, %arg5 : f32
    linalg.yield %2 : f32
  } -> tensor<?xf32>
  return %0 : tensor<?xf32>
}
```

So the elementwise ops on tensors have nicely collapsed into a single linalg.generic, which is the form we want for further transformations.

Differential Revision: https://reviews.llvm.org/D90354
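For readers who want to schedule this programmatically rather than via `mlir-opt`, here is a minimal sketch of one plausible pipeline setup using the entry points declared in this patch. It is not part of the patch: `buildElementwisePipeline` is a hypothetical helper, and the extra include that provides `FuncOp` varies across MLIR revisions, so it is omitted.

```cpp
#include "mlir/Dialect/Linalg/Passes.h" // declares createConvertElementwiseToLinalgPass()
#include "mlir/Pass/PassManager.h"

// Hypothetical helper (not part of this patch): run the new elementwise-to-linalg
// conversion on every function, then fuse the resulting linalg.generic ops on tensors.
static void buildElementwisePipeline(mlir::PassManager &pm) {
  pm.addNestedPass<mlir::FuncOp>(mlir::createConvertElementwiseToLinalgPass());
  pm.addPass(mlir::createLinalgFusionOfTensorOpsPass());
}
```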
parent b4fa28b408
commit 53a0d45db6
```diff
@@ -16,6 +16,8 @@
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
+std::unique_ptr<OperationPass<FuncOp>> createConvertElementwiseToLinalgPass();
+
 std::unique_ptr<OperationPass<FuncOp>> createLinalgFoldUnitExtentDimsPass();
 
 std::unique_ptr<Pass> createLinalgFusionOfTensorOpsPass();
```
```diff
@@ -48,6 +50,11 @@ std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToAffineLoopsPass();
 /// buffers instead.
 std::unique_ptr<OperationPass<ModuleOp>> createLinalgBufferizePass();
 
+/// Populate patterns that convert `ElementwiseMappable` ops to linalg
+/// parallel loops.
+void populateElementwiseToLinalgConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
 /// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
 /// producer (consumer) generic operation by expanding the dimensionality of the
 /// loop in the generic op.
```
```diff
@@ -11,6 +11,17 @@
 
 include "mlir/Pass/PassBase.td"
 
+def ConvertElementwiseToLinalg : FunctionPass<"convert-elementwise-to-linalg"> {
+  let summary = "Convert ElementwiseMappable ops to linalg";
+  let description = [{
+    Convert ops with the `ElementwiseMappable` trait to linalg parallel loops.
+
+    This pass only converts ops that operate on ranked tensors.
+  }];
+  let constructor = "mlir::createConvertElementwiseToLinalgPass()";
+  let dependentDialects = ["linalg::LinalgDialect"];
+}
+
 def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> {
   let summary = "Remove unit-extent dimension in Linalg ops on tensors";
   let constructor = "mlir::createLinalgFoldUnitExtentDimsPass()";
```
```diff
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s -convert-elementwise-to-linalg -std-bufferize -linalg-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func @main() {
+  %a = constant dense<[1.0, 2.0, 3.0]> : tensor<3xf32>
+  %b = constant dense<[10.0, 20.0, 30.0]> : tensor<3xf32>
+
+  %addf = addf %a, %b : tensor<3xf32>
+  %addf_unranked = tensor_cast %addf : tensor<3xf32> to tensor<*xf32>
+  call @print_memref_f32(%addf_unranked) : (tensor<*xf32>) -> ()
+  // CHECK: Unranked Memref base@ = {{.*}} rank = 1 offset = 0 sizes = [3] strides = [1] data =
+  // CHECK-NEXT: [11, 22, 33]
+
+  return
+}
+
+func @print_memref_f32(%ptr : tensor<*xf32>)
```
```diff
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   Bufferize.cpp
   CodegenStrategy.cpp
   DropUnitDims.cpp
+  ElementwiseToLinalg.cpp
   Fusion.cpp
   FusionOnTensors.cpp
   Hoisting.cpp
```
```diff
@@ -0,0 +1,98 @@
+//===- ElementwiseToLinalg.cpp - conversion of elementwise to linalg ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
+  if (!op->hasTrait<OpTrait::ElementwiseMappable>())
+    return false;
+
+  // TODO: The conversion pattern can be made to work for `any_of` here, but
+  // it's more complex as it requires tracking which operands are scalars.
+  return llvm::all_of(op->getOperandTypes(),
+                      [](Type type) { return type.isa<RankedTensorType>(); });
+}
+
+namespace {
+struct ConvertStdElementwiseOpOnRankedTensors : public RewritePattern {
+  ConvertStdElementwiseOpOnRankedTensors()
+      : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const final {
+    if (!isElementwiseMappableOpOnRankedTensors(op))
+      return rewriter.notifyMatchFailure(
+          op, "requires elementwise op on ranked tensors");
+
+    auto rank = op->getResult(0).getType().cast<RankedTensorType>().getRank();
+    SmallVector<AffineMap, 3> indexingMaps(
+        op->getNumResults() + op->getNumOperands(),
+        rewriter.getMultiDimIdentityMap(rank));
+    SmallVector<StringRef, 6> iteratorTypes(rank,
+                                            getParallelIteratorTypeName());
+    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+        op, /*resultTensorTypes=*/op->getResultTypes(),
+        /*inputs=*/op->getOperands(),
+        /*outputBuffers=*/ValueRange(),
+        /*initTensors=*/ValueRange(),
+        /*indexingMaps=*/indexingMaps,
+        /*iteratorTypes=*/iteratorTypes,
+        /*bodyBuilder=*/
+        [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
+          OperationState state(loc, op->getName());
+          state.addAttributes(op->getAttrs());
+          state.addOperands(regionArgs);
+          auto resultTypes = llvm::to_vector<6>(
+              llvm::map_range(op->getResultTypes(), [](Type type) {
+                return type.cast<TensorType>().getElementType();
+              }));
+          state.addTypes(resultTypes);
+          auto *scalarOp = builder.createOperation(state);
+          builder.create<linalg::YieldOp>(loc, scalarOp->getResults());
+        });
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateElementwiseToLinalgConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *) {
+  patterns.insert<ConvertStdElementwiseOpOnRankedTensors>();
+}
+
+namespace {
+class ConvertElementwiseToLinalgPass
+    : public ConvertElementwiseToLinalgBase<ConvertElementwiseToLinalgPass> {
+
+  void runOnFunction() final {
+    auto func = getOperation();
+    auto *context = &getContext();
+    ConversionTarget target(*context);
+    OwningRewritePatternList patterns;
+
+    populateElementwiseToLinalgConversionPatterns(patterns, context);
+    target.markUnknownOpDynamicallyLegal([](Operation *op) {
+      return !isElementwiseMappableOpOnRankedTensors(op);
+    });
+
+    if (failed(applyPartialConversion(func, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createConvertElementwiseToLinalgPass() {
+  return std::make_unique<ConvertElementwiseToLinalgPass>();
+}
```
```diff
@@ -0,0 +1,60 @@
+// RUN: mlir-opt -convert-elementwise-to-linalg -split-input-file %s | FileCheck %s
+
+// In-depth checking of the linalg.generic op for a very trivial case.
+// CHECK: #map = affine_map<() -> ()>
+// CHECK-LABEL: func @addf_rank0
+func @addf_rank0(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
+  // CHECK: %{{.*}} = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} ins(%{{.*}}, %{{.*}} : tensor<f32>, tensor<f32>) {
+  // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+  // CHECK:   %[[YIELD:.*]] = addf %[[LHS]], %[[RHS]] : f32
+  // CHECK:   linalg.yield %[[YIELD]] : f32
+  // CHECK: } -> tensor<f32>
+  %0 = addf %arg0, %arg1 : tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
+// Check indexing maps and iterator types for the rank > 0 case.
+// CHECK: #map = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @addf_rank1
+func @addf_rank1(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+  // CHECK: linalg.generic{{.*}}indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]
+  %0 = addf %arg0, %arg1 : tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// Check a unary op.
+// CHECK-LABEL: func @exp
+func @exp(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: linalg.generic
+  // CHECK: ^bb0(%[[SCALAR:.*]]: f32):
+  // CHECK:   %[[YIELD:.*]] = exp %[[SCALAR]] : f32
+  // CHECK:   linalg.yield %[[YIELD]] : f32
+  %0 = exp %arg0 : tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
+// Check a case with varying operand types.
+// CHECK-LABEL: func @select
+func @select(%arg0: tensor<i1>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<i32> {
+  // CHECK: linalg.generic
+  // CHECK: ^bb0(%[[PRED:.*]]: i1, %[[TRUE_VAL:.*]]: i32, %[[FALSE_VAL:.*]]: i32):
+  // CHECK:   select %[[PRED]], %[[TRUE_VAL]], %[[FALSE_VAL]] : i32
+  %0 = select %arg0, %arg1, %arg2 : tensor<i1>, tensor<i32>
+  return %0 : tensor<i32>
+}
+
+// -----
+
+// Spot-check an op that requires copying attributes properly to the created scalar op.
+// CHECK-LABEL: func @cmpf(
+func @cmpf(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<i1> {
+  // CHECK: cmpf "olt", %{{.*}}, %{{.*}} : f32
+  %0 = cmpf "olt", %arg0, %arg1 : tensor<f32>
+  return %0 : tensor<i1>
+}
```