[mlir][Linalg] Add canonicalization of linalg op -> dim op.

Add canonicalization to replace use of the result of a linalg
operation on tensors in a dim operation, to use one of the operands of
the linalg operations instead. This allows the linalg op itself to be
deleted when all its non-dim uses are removed (say through tiling, etc.)

Differential Revision: https://reviews.llvm.org/D93076
This commit is contained in:
MaheshRavishankar 2021-01-14 15:41:04 -08:00
parent 547b032ccc
commit 774c9c6ef3
7 changed files with 344 additions and 47 deletions

View File

@ -32,6 +32,9 @@ def Linalg_Dialect : Dialect {
the op semantics.
}];
let cppNamespace = "::mlir::linalg";
let dependentDialects = [
"AffineDialect", "StandardOpsDialect", "tensor::TensorDialect"
];
}
// Whether a type is a RangeType.

View File

@ -946,6 +946,56 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
return inversePermutation(getLoopsToShapesMap());
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the position in the results of the affine map computed
by getLoopsToShapesMap() that represents the shape of an
operand (input or output) at a dimension.
}],
/*retTy=*/"Optional<unsigned>",
/*methodName=*/"getOperandDimPositionInLoopsToShapeMap",
/*args=*/(ins "unsigned":$operandIdx, "unsigned":$dim),
/*methodBody=*/"",
/*defaultImplementation=*/[{
unsigned pos = 0;
for (auto type : llvm::enumerate(getShapedOperandTypes())) {
if (type.index() == operandIdx) return pos + dim;
pos += type.value().getRank();
}
return {};
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the position in the results of the affine map computed
by getLoopsToShapesMap() that represents the shape of an
input operand at a dimension.
}],
/*retTy=*/"Optional<unsigned>",
/*methodName=*/"getInputValueDimPositionInLoopsToShapeMap",
/*args=*/(ins "unsigned":$inputIdx, "unsigned":$dim),
/*methodBody=*/"",
/*defaultImplementation=*/[{
if (inputIdx >= getNumInputs()) return {};
return getOperandDimPositionInLoopsToShapeMap(inputIdx, dim);
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the position in the results of the affine map computed
by getLoopsToShapesMap() that represents the shape of the
result value at a dimension.
}],
/*retTy=*/"Optional<unsigned>",
/*methodName=*/"getResultValueDimPositionInLoopsToShapeMap",
/*args=*/(ins "unsigned":$resultIdx, "unsigned":$dim),
/*methodBody=*/"",
/*defaultImplementation=*/[{
if (resultIdx >= getNumOutputs()) return {};
return getOperandDimPositionInLoopsToShapeMap(
getNumInputs() + resultIdx, dim);
}]
>,
//===------------------------------------------------------------------===//
// Other static interface methods.
@ -1027,6 +1077,12 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
}
return res;
}
/// Returns the value that expresses the shape of the output in terms of
/// shape of the input operands where possible
Optional<Value> inferResultDimFromInputShapes
(OpBuilder &b, Location loc, unsigned resultIdx, unsigned im);
//========================================================================//
// Helper functions to mutate the `operand_segment_sizes` attribute.
// These are useful when cloning and changing operand types.

View File

@ -9,6 +9,9 @@
#ifndef MLIR_DIALECT_LINALG_LINALGTYPES_H_
#define MLIR_DIALECT_LINALG_LINALGTYPES_H_
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Types.h"

View File

@ -159,29 +159,29 @@ public:
// Default visit methods. Note that the default op-specific binary op visit
// methods call the general visitAffineBinaryOpExpr visit method.
void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {}
void visitAddExpr(AffineBinaryOpExpr expr) {
static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
RetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { return RetTy(); }
RetTy visitAddExpr(AffineBinaryOpExpr expr) {
return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
void visitMulExpr(AffineBinaryOpExpr expr) {
static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
RetTy visitMulExpr(AffineBinaryOpExpr expr) {
return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
void visitModExpr(AffineBinaryOpExpr expr) {
static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
RetTy visitModExpr(AffineBinaryOpExpr expr) {
return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
void visitFloorDivExpr(AffineBinaryOpExpr expr) {
static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
RetTy visitFloorDivExpr(AffineBinaryOpExpr expr) {
return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
void visitCeilDivExpr(AffineBinaryOpExpr expr) {
static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
RetTy visitCeilDivExpr(AffineBinaryOpExpr expr) {
return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
void visitConstantExpr(AffineConstantExpr expr) {}
void visitDimExpr(AffineDimExpr expr) {}
void visitSymbolExpr(AffineSymbolExpr expr) {}
RetTy visitConstantExpr(AffineConstantExpr expr) { return RetTy(); }
RetTy visitDimExpr(AffineDimExpr expr) { return RetTy(); }
RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); }
private:
// Walk the operands - each operand is itself walked in post order.
void walkOperandsPostOrder(AffineBinaryOpExpr expr) {
RetTy walkOperandsPostOrder(AffineBinaryOpExpr expr) {
walkPostOrder(expr.getLHS());
walkPostOrder(expr.getRHS());
}

View File

@ -16,12 +16,14 @@
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
@ -86,6 +88,82 @@ SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
return res;
}
/// Visitor to check if any of the given set of positions from AffineDimExprs
/// are used within an AffineExpr.
struct HasAffineDimExprVisitor
: public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
HasAffineDimExprVisitor(llvm::SmallSet<unsigned, 4> &positions)
: positions(positions) {}
bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) {
return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
}
bool visitDimExpr(AffineDimExpr dimExpr) {
return positions.count(dimExpr.getPosition());
}
bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }
bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }
private:
llvm::SmallSet<unsigned, 4> positions;
};
Optional<Value> LinalgOp::inferResultDimFromInputShapes(OpBuilder &b,
Location loc,
unsigned resultIdx,
unsigned dim) {
// An example that helps understand the logic below.
// Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
// We want to express the shape of dim 0 of O in terms of shape of the inputs.
// This is achieved as follows.
// loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
// subMapOfResultDim = (d0, d1, d2) -> (d0 + d1)
// shapesToLoopsMap = (d0, d2, d2, d3, d4, d5) -> (d0, d3, d2)
// resultFromFromInputDim = subMapOfResultDim.compose(shapesToLoopMap)
// = (d0, d1, d2, d3, d4, d5) -> (d0 + d1)
AffineMap loopsToShapesMap = getLoopsToShapesMap();
// Find the position in the above map that represents the shape of the
// result:dim being inferred.
Optional<unsigned> resultDimSubMapPos =
getResultValueDimPositionInLoopsToShapeMap(resultIdx, dim);
if (!resultDimSubMapPos)
return {};
/// From loopsToShapesMap extract the submap that represents the shape of the
/// (resultIdx, dim) needed
AffineMap loopToResultDimShapeMap =
loopsToShapesMap.getSubMap(*resultDimSubMapPos);
AffineMap operandShapesToResultDimMap =
loopToResultDimShapeMap.compose(getShapesToLoopsMap());
// Check that the result dim map does not contain the positions corresponding
// to the outputs.
llvm::SmallSet<unsigned, 4> outputDims;
unsigned outputDimPosStart =
getResultValueDimPositionInLoopsToShapeMap(0, 0).getValue();
unsigned outputDimPosEnd =
getResultValueDimPositionInLoopsToShapeMap(getNumOutputs() - 1,
getOutputOpOperands()
.back()
.get()
.getType()
.cast<ShapedType>()
.getRank() -
1)
.getValue();
llvm::for_each(llvm::seq<unsigned>(outputDimPosStart, outputDimPosEnd),
[&outputDims](unsigned dim) { outputDims.insert(dim); });
HasAffineDimExprVisitor checkDimExpr(outputDims);
if (checkDimExpr.visit(operandShapesToResultDimMap.getResult(0)))
return llvm::None;
return applyMapToValues(b, loc, operandShapesToResultDimMap,
createFlatListOfOperandDims(b, loc))[0];
}
/// Forward declarations.
template <typename NamedStructuredOpType>
static void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder,
@ -2022,6 +2100,49 @@ struct FoldTensorCastOp : public RewritePattern {
return success();
}
};
/// Replaces std.dim operations that use the result of a LinalgOp (on tensors)
/// with std.dim operations that use one of the arguments. For example,
///
/// %0 = linalg.matmul ins(%arg0, %arg1, ...)
/// %1 = dim %0, %c0
///
/// with
///
/// %1 = dim %arg0, %c0
///
/// where possible. With this the result of the `linalg.matmul` is not used in
/// dim operations. If the value produced is replaced with another value (say by
/// tiling `linalg.matmul`) will make the `linalg.matmul` truly dead instead of
/// used in a dim op that would prevent the DCE of this op.
struct ReplaceDimOfLinalgOpResult : public OpRewritePattern<DimOp> {
using OpRewritePattern<DimOp>::OpRewritePattern;
LogicalResult matchAndRewrite(DimOp dimOp,
PatternRewriter &rewriter) const override {
Value dimValue = dimOp.memrefOrTensor();
Optional<int64_t> dimIndex = dimOp.getConstantIndex();
if (!dimIndex)
return failure();
auto linalgOp = dimValue.getDefiningOp<LinalgOp>();
if (!linalgOp)
return failure();
unsigned resultIndex = dimValue.cast<OpResult>().getResultNumber();
Optional<Value> operandDimValue = linalgOp.inferResultDimFromInputShapes(
rewriter, dimOp.getLoc(), resultIndex,
static_cast<unsigned>(*dimIndex));
if (!operandDimValue) {
// Its always possible to replace using the corresponding `outs`
// parameter.
operandDimValue = rewriter.create<DimOp>(
dimOp.getLoc(), linalgOp.getOutput(resultIndex), *dimIndex);
}
rewriter.replaceOp(dimOp, *operandDimValue);
return success();
}
};
} // namespace
namespace {
@ -2166,26 +2287,6 @@ struct RemoveIdentityLinalgOps : public RewritePattern {
return success();
}
};
/// Canonicalize a `linalgOp` -> `dim` pattern by replacing the `dim` arg
/// with the corresponding output tensor argument of the linalg op.
struct ReplaceDimOfLinalgResult : public OpRewritePattern<DimOp> {
using OpRewritePattern<DimOp>::OpRewritePattern;
LogicalResult matchAndRewrite(DimOp dimOp,
PatternRewriter &rewriter) const override {
Value dimOpArg = dimOp.memrefOrTensor();
auto linalgOp = dimOpArg.getDefiningOp<LinalgOp>();
if (!linalgOp)
return failure();
auto results = linalgOp.getOperation()->getResults();
int64_t id = std::distance(results.begin(), llvm::find(results, dimOpArg));
auto outputTensors = linalgOp.getOutputTensors();
rewriter.replaceOpWithNewOp<DimOp>(dimOp, outputTensors[id], dimOp.index());
return success();
}
};
} // namespace
#define CANONICALIZERS_AND_FOLDERS(XXX) \
@ -2193,7 +2294,7 @@ struct ReplaceDimOfLinalgResult : public OpRewritePattern<DimOp> {
MLIRContext *context) { \
results.insert<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp, \
RemoveIdentityLinalgOps>(); \
results.insert<ReplaceDimOfLinalgResult>(context); \
results.insert<ReplaceDimOfLinalgOpResult>(context); \
} \
\
LogicalResult XXX::fold(ArrayRef<Attribute>, \

View File

@ -58,9 +58,6 @@ struct LinalgInlinerInterface : public DialectInlinerInterface {
//===----------------------------------------------------------------------===//
void mlir::linalg::LinalgDialect::initialize() {
getContext()->getOrLoadDialect("std");
getContext()->getOrLoadDialect("tensor");
addTypes<RangeType>();
addOperations<
#define GET_OP_LIST

View File

@ -390,10 +390,147 @@ func @init_tensor_dynamic_dim(%arg0 : index) -> (index) {
// -----
func @init_tensor_dynamic_dim2(%arg0 : index, %arg1 : index) -> (index, index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
%1 = dim %0, %c0 : tensor<?x?xf32>
%2 = dim %0, %c1 : tensor<?x?xf32>
return %1, %2 : index, index
}
// CHECK: func @init_tensor_dynamic_dim2
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK: return %[[ARG0]], %[[ARG1]]
// -----
func @remove_dim_result_uses
(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
%arg2 : tensor<?x?xf32>) -> (index) {
%c0 = constant 0 : index
%0 = linalg.generic
{indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
affine_map<(d0, d1, d2) -> (d2, d1)>,
affine_map<(d0, d1, d2) -> (d0 + d1, d1)>],
iterator_types = ["parallel", "parallel", "reduction"]}
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg2 : tensor<?x?xf32>) {
^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
%1 = mulf %arg3, %arg4 : f32
%2 = addf %1, %arg5 : f32
linalg.yield %2 : f32
} -> tensor<?x?xf32>
%3 = dim %0, %c0 : tensor<?x?xf32>
return %3 : index
}
// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
// CHECK: func @remove_dim_result_uses
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[T0:.+]] = dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[T1:.+]] = dim %[[ARG1]], %[[C1]]
// CHECK: %[[T2:.+]] = affine.apply #[[MAP]]()[%[[T0]], %[[T1]]]
// CHECK: return %[[T2]]
// -----
func @remove_dim_result_uses_outs
(%arg0 : tensor<?xf32>, %arg1 : index) -> (index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%d0 = dim %arg0, %c0 : tensor<?xf32>
%0 = linalg.init_tensor [%d0, %arg1] : tensor<?x?xf32>
%1 = linalg.generic
{indexing_maps = [affine_map<(d0, d1) -> (d0)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%arg0 : tensor<?xf32>) outs(%0 : tensor<?x?xf32>) {
^bb0(%arg2: f32, %arg3: f32) :
linalg.yield %arg2 : f32
} -> tensor<?x?xf32>
%2 = dim %1, %c1 : tensor<?x?xf32>
return %2 : index
}
// CHECK: func @remove_dim_result_uses_outs
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK: return %[[ARG1]]
// -----
func @remove_dim_result_uses_sequence
(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
%arg2 : tensor<?x?xf32>) -> (index, index, index, index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = dim %0, %c0 : tensor<?x?xf32>
%2 = dim %0, %c1 : tensor<?x?xf32>
%3 = linalg.generic
{indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0)>,
affine_map<(d0, d1, d2) -> (d0, d2)>,
affine_map<(d0, d1, d2) -> (d0, d2)>],
iterator_types = ["parallel", "reduction", "parallel"]}
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%0 : tensor<?x?xf32>) {
^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
%4 = mulf %arg3, %arg4 : f32
%5 = addf %4, %arg5 : f32
linalg.yield %5 : f32
} -> tensor<?x?xf32>
%6 = dim %3, %c0 : tensor<?x?xf32>
%7 = dim %3, %c1 : tensor<?x?xf32>
return %1, %2, %6, %7 : index, index, index, index
}
// CHECK-LABEL: func @remove_dim_result_uses_sequence
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[T0:.+]] = dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[T1:.+]] = dim %[[ARG1]], %[[C1]]
// CHECK-DAG: %[[T2:.+]] = dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[T3:.+]] = dim %[[ARG1]], %[[C1]]
// CHECK: return %[[T0]], %[[T1]], %[[T2]], %[[T3]]
// -----
func @keep_result_dim_uses_sequence2
(%arg0 : tensor<?xf32>, %arg1 : index) -> (index, index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%d0 = dim %arg0, %c0 : tensor<?xf32>
%0 = linalg.init_tensor [%d0, %arg1] : tensor<?x?xf32>
%1 = linalg.generic
{indexing_maps = [affine_map<(d0, d1) -> (d0)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%arg0 : tensor<?xf32>) outs(%0 : tensor<?x?xf32>) {
^bb0(%arg2: f32, %arg3 : f32):
linalg.yield %arg2 : f32
} -> tensor<?x?xf32>
%2 = dim %1, %c0 : tensor<?x?xf32>
%3 = dim %1, %c1 : tensor<?x?xf32>
return %2, %3 : index, index
}
// CHECK: func @keep_result_dim_uses_sequence2
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[T0:.+]] = dim %[[ARG0]], %[[C0]]
// CHECK: return %[[T0]], %[[ARG1]]
// -----
#map = affine_map<(d0) -> (d0)>
func @init_tensor_dim_of_linalg_result(%arg_0 : tensor<?xf32>,
%arg_1: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
%arg_1: tensor<?xf32>) -> (index, index) {
%0, %1 = linalg.generic {
indexing_maps = [#map, #map, #map],
iterator_types = ["parallel"]
@ -405,16 +542,16 @@ func @init_tensor_dim_of_linalg_result(%arg_0 : tensor<?xf32>,
%c0 = constant 0 : index
%num_elem_0 = dim %0, %c0 : tensor<?xf32>
%result_0 = linalg.init_tensor [%num_elem_0] : tensor<?xf32>
%num_elem_1 = dim %1, %c0 : tensor<?xf32>
%result_1 = linalg.init_tensor [%num_elem_1] : tensor<?xf32>
return %result_0, %result_1 : tensor<?xf32>, tensor<?xf32>
return %num_elem_0, %num_elem_1 : index, index
}
// CHECK-LABEL: func @init_tensor_dim_of_linalg_result(
// CHECK-SAME: [[ARG_0:%.*]]: tensor<?xf32>, [[ARG_1:%.*]]: tensor<?xf32>)
// CHECK: dim [[ARG_0]]
// CHECK: dim [[ARG_1]]
// CHECK: func @init_tensor_dim_of_linalg_result(
// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?xf32>
// CHECK-SAME: %[[ARG_1:[a-zA-Z0-9_]+]]: tensor<?xf32>)
// CHECK: %[[R0:.+]] = dim %[[ARG_0]]
// CHECK: %[[R1:.+]] = dim %[[ARG_0]]
// CHECK: return %[[R0]], %[[R1]]
// -----