[mlir] Enhance InferShapedTypeOpInterface and move LinalgOps to use them.

A new `InterfaceMethod` is added to `InferShapedTypeOpInterface` that
allows an operation to return the `Value`s for each dim of its
results. It is intended for the case where the `Value` returned for
each dim is computed using the operands and operation attributes. This
interface method is for cases where the result dim of an operation can
be computed independently, and it avoids the need to aggregate all
dims of a result into a single shape value. This also implies that
this is not suitable for cases where the result type is unranked (for
which the existing interface method is to be used).

Also added is a canonicalization pattern that uses this interface and
resolves the shapes of the output in terms of the shapes of the
inputs. Moving Linalg ops to use this interface, so that many
canonicalization patterns implemented for individual linalg ops to
achieve the same result can be removed in favor of the added
canonicalization pattern.

Differential Revision: https://reviews.llvm.org/D97887
This commit is contained in:
MaheshRavishankar 2021-03-29 10:57:23 -07:00
parent 742f663705
commit 9b0517035f
15 changed files with 395 additions and 209 deletions

View File

@ -1087,18 +1087,20 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
>,
InterfaceMethod<
/*desc=*/[{
Return the position in the results of the affine map computed
by getLoopsToShapesMap() that represents the shape of the
result value at a dimension.
Return the range of position in the result of the affine map
computed by getLoopsToShapesMap() which correspond to the
AffineExprs used to access the outputs of the operation.
}],
/*retTy=*/"Optional<unsigned>",
/*methodName=*/"getResultValueDimPositionInLoopsToShapeMap",
/*args=*/(ins "unsigned":$resultIdx, "unsigned":$dim),
/*retTy=*/"std::pair<unsigned, unsigned>",
/*methodName=*/"getResultsPositionInLoopsToShapeMap",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
if (resultIdx >= getNumOutputs()) return {};
return getOperandDimPositionInLoopsToShapeMap(
getNumInputs() + resultIdx, dim);
return
{*getOperandDimPositionInLoopsToShapeMap(getNumInputs(), 0),
(*getOperandDimPositionInLoopsToShapeMap
(getNumInputs() + getNumOutputs() - 1,
getOutputShapedType(getNumOutputs()-1).getRank() - 1)) + 1};
}]
>,
InterfaceMethod<
@ -1226,8 +1228,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/// Returns the value that expresses the shape of the output in terms of
/// shape of the input operands where possible
Optional<Value> inferResultDimFromInputShapes
(OpBuilder &b, Location loc, unsigned resultIdx, unsigned dim);
LogicalResult reifyReturnTypeShapesPerResultDim(OpBuilder &b,
SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes);
//========================================================================//
// Helper functions to mutate the `operand_segment_sizes` attribute.

View File

@ -22,6 +22,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/CopyOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/LLVM.h"
@ -107,13 +108,6 @@ SmallVector<AffineExpr, 4> concat(ArrayRef<AffineExpr> a,
void getDimsOfType(Operation *op, StringRef iteratorTypeName,
SmallVectorImpl<AffineExpr> &res);
/// For reshape operation, compute the shape of the output based on the result
/// type and shape of the input.
SmallVector<Value, 4>
getReshapeOutputShapeFromInputShape(OpBuilder &b, Location loc, Value src,
ArrayRef<int64_t> dstStaticShape,
ArrayRef<AffineMap> reassociation);
namespace detail {
LogicalResult verifyStructuredOpInterface(Operation *op);
} // namespace detail

View File

@ -15,6 +15,7 @@
include "mlir/Dialect/Linalg/IR/LinalgBase.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
@ -33,7 +34,10 @@ class Linalg_Op<string mnemonic, list<OpTrait> traits = []> :
let parser = [{ return ::parse$cppClass(parser, result); }];
}
def Linalg_InitTensorOp : Linalg_Op<"init_tensor", [NoSideEffect]> {
def Linalg_InitTensorOp : Linalg_Op<"init_tensor",
[NoSideEffect,
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["reifyReturnTypeShapesPerResultDim"]>]> {
let summary = "operation to define a tensor of particular value";
let description = [{
@ -126,7 +130,10 @@ def Linalg_InitTensorOp : Linalg_Op<"init_tensor", [NoSideEffect]> {
}
def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
[AttrSizedOperandSegments, NoSideEffect]> {
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["reifyReturnTypeShapesPerResultDim"]>,
NoSideEffect]> {
let summary = "tensor pad operation";
let description = [{
`linalg.pad_tensor` is an operation that pads the `source` tensor
@ -348,11 +355,6 @@ class Linalg_ReshapeLikeOp<string mnemonic, list<OpTrait> traits = []> :
a.cast<AffineMapAttr>().getValue().getResults());
}));
}
SmallVector<Value, 4> getOutputShape(OpBuilder &b, Location loc) {
return getReshapeOutputShapeFromInputShape(
b, loc, src(), getResultType().getShape(),
getReassociationMaps());
}
}];
let assemblyFormat = [{
$src $reassociation attr-dict `:` type($src) `into` type(results)
@ -417,7 +419,10 @@ def Linalg_ReshapeOp : Linalg_ReshapeLikeOp<"reshape",
let hasCanonicalizer = 1;
}
def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<"tensor_reshape">,
def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<
"tensor_reshape",
[DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["reifyReturnTypeShapesPerResultDim"]>]>,
Arguments<(ins AnyTensor:$src,
AffineMapArrayAttr:$reassociation)>,
Results<(outs AnyTensor:$result)> {

View File

@ -17,6 +17,7 @@
include "mlir/Dialect/Linalg/IR/LinalgBase.td"
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
include "mlir/Interfaces/CopyOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
// Base Tablegen class for Linalg ops.
@ -25,7 +26,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
// depending on the specific Linalg op.
class LinalgStructuredBase_Op<string mnemonic, list<OpTrait> props>
: Op<Linalg_Dialect, mnemonic, !listconcat(props, [
LinalgStructuredInterface])> {
LinalgStructuredInterface, InferShapedTypeOpInterface])> {
code structuredOpsBaseDecls = [{
// Return the number of induction variables in the basic block. This should
// always be 0 for index-free linalg ops. For IndexedGeneric, this must be
@ -33,6 +34,12 @@ class LinalgStructuredBase_Op<string mnemonic, list<OpTrait> props>
unsigned getNumPayloadInductionVariables() {
return isa<IndexedGenericOp>(this->getOperation()) ? getNumLoops() : 0;
}
LogicalResult reifyReturnTypeShapesPerResultDim(OpBuilder &b,
SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
return cast<LinalgOp>(getOperation()).reifyReturnTypeShapesPerResultDim(b,
reifiedReturnShapes);
}
}];
}

View File

@ -97,21 +97,53 @@ def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
"::mlir::DictionaryAttr":$attributes,
"::mlir::RegionRange":$regions,
"::mlir::SmallVectorImpl<::mlir::ShapedTypeComponents>&":
$inferredReturnShapes)
$inferredReturnShapes),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{ return ::mlir::failure(); }]
>,
InterfaceMethod<
/*desc=*/[{Reify the shape computation for the operation.
Insert operations using the given OpBuilder that computes the result
shape.
Insert operations using the given OpBuilder that computes the
result shape. Only one of this method or
`reifyReturnTypeShapesPerResultDim` needs to be overridden by the
operation.
}],
/*retTy=*/"::mlir::LogicalResult",
/*methodName=*/"reifyReturnTypeShapes",
/*args=*/(ins "::mlir::OpBuilder&":$builder,
"::mlir::SmallVectorImpl<::mlir::Value>&":$reifiedReturnShapes),
"::mlir::SmallVectorImpl<Value> &":$reifiedReturnShapes),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{ return ::mlir::failure(); }]
>,
InterfaceMethod<
/*desc=*/[{Reify the shape computation for the operation.
Insert operations using the given OpBuilder that computes the
result shape. The `reifiedReturnShapes` is expected to be
populated with as many vectors as the number of results of the
op (empty if the shape of a result value cannot be computed). If
the returned shape for a result is not empty, its size must
match the rank of the shaped type returned. Consequently, this
interface can only be overridden if the return types are ranked.
If both this method and `reifyReturnTypeShapes` are overridden
by the operation, `reifyReturnTypeShapes` takes precedence. This
method is intended to be used when the shape of each result, dim
pair can be computed independently. Using this method avoids
adding additional instructions to aggregate individual dimension
of a result shape into an single `Value` (and consequently
avoids the need to extract the value from the shape on the
client side).
}],
/*retTy=*/"::mlir::LogicalResult",
/*methodName=*/"reifyReturnTypeShapesPerResultDim",
/*args=*/(ins "::mlir::OpBuilder&":$builder,
"::mlir::SmallVectorImpl<SmallVector<::mlir::Value>>&"
:$reifiedReturnShapes),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{ return ::mlir::failure(); }]
>
];
}
@ -129,6 +161,7 @@ class InferTensorType<list<string> overridenMethods = []> {
NativeOpTrait<"InferTensorType">
];
}
defvar InferTensorTypeWithReify = InferTensorType<["reifyReturnTypeShapes"]>;
defvar InferTensorTypeWithReify = InferTensorType<[
"inferReturnTypeComponents", "reifyReturnTypeShapes"]>;
#endif // MLIR_INFERTYPEOPINTERFACE

View File

@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRLinalg
LINK_LIBS PUBLIC
MLIRAffine
MLIRDialectUtils
MLIRInferTypeOpInterface
MLIRIR
MLIRParser
MLIRSideEffectInterfaces

View File

@ -188,7 +188,7 @@ SmallVector<Value, 4> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
for (Value v : getShapedOperands()) {
ShapedType t = v.getType().template cast<ShapedType>();
for (unsigned i = 0, e = t.getRank(); i < e; ++i)
res.push_back(b.create<memref::DimOp>(loc, v, i));
res.push_back(b.createOrFold<memref::DimOp>(loc, v, i));
}
return res;
}
@ -234,57 +234,58 @@ private:
llvm::SmallSet<unsigned, 4> positions;
};
Optional<Value> LinalgOp::inferResultDimFromInputShapes(OpBuilder &b,
Location loc,
unsigned resultIdx,
unsigned dim) {
LogicalResult LinalgOp::reifyReturnTypeShapesPerResultDim(
OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
// An example that helps understand the logic below.
// Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
// We want to express the shape of dim 0 of O in terms of shape of the inputs.
// This is achieved as follows.
// loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
// subMapOfResultDim = (d0, d1, d2) -> (d0 + d1)
// subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
// shapesToLoopsMap = (d0, d1, d2, d3, d4, d5) -> (d0, d3, d2)
// resultFromInputDim = subMapOfResultDim.compose(shapesToLoopMap)
// = (d0, d1, d2, d3, d4, d5) -> (d0 + d1)
// resultShapesFromInputShapes = subMapOfResultDim.compose(shapesToLoopMap)
// = (d0, d1, d2, d3, d4, d5) -> (d0 + d1, d1)
AffineMap loopsToShapesMap = getLoopsToShapesMap();
// Find the position in the above map that represents the shape of the
// result:dim being inferred.
Optional<unsigned> resultDimSubMapPos =
getResultValueDimPositionInLoopsToShapeMap(resultIdx, dim);
if (!resultDimSubMapPos)
return {};
auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap();
/// From loopsToShapesMap extract the submap that represents the shape of the
/// (resultIdx, dim) needed
AffineMap loopToResultDimShapeMap =
loopsToShapesMap.getSubMap(*resultDimSubMapPos);
AffineMap operandShapesToResultDimMap =
loopToResultDimShapeMap.compose(getShapesToLoopsMap());
/// (resultIdx, dim) needed.
SmallVector<unsigned, 4> resultPosRange =
llvm::to_vector<4>(llvm::seq<unsigned>(resultShapesSubMapPos.first,
resultShapesSubMapPos.second));
AffineMap loopToResultsShapeMap = loopsToShapesMap.getSubMap(resultPosRange);
AffineMap resultShapesFromInputShapesMap =
loopToResultsShapeMap.compose(getShapesToLoopsMap());
// Check that the result dim map does not contain the positions corresponding
// to the outputs.
llvm::SmallSet<unsigned, 4> outputDims;
unsigned outputDimPosStart =
getResultValueDimPositionInLoopsToShapeMap(0, 0).getValue();
unsigned outputDimPosEnd =
getResultValueDimPositionInLoopsToShapeMap(getNumOutputs() - 1,
getOutputOpOperands()
.back()
.get()
.getType()
.cast<ShapedType>()
.getRank() -
1)
.getValue();
llvm::for_each(llvm::seq<unsigned>(outputDimPosStart, outputDimPosEnd),
llvm::for_each(resultPosRange,
[&outputDims](unsigned dim) { outputDims.insert(dim); });
HasAffineDimExprVisitor checkDimExpr(outputDims);
if (checkDimExpr.visit(operandShapesToResultDimMap.getResult(0)))
return llvm::None;
return applyMapToValues(b, loc, operandShapesToResultDimMap,
createFlatListOfOperandDims(b, loc))[0];
Location loc = getOperation()->getLoc();
auto allResultDimValues =
applyMapToValues(b, loc, resultShapesFromInputShapesMap,
createFlatListOfOperandDims(b, loc));
unsigned pos = 0;
ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
for (auto resultIdx : llvm::seq<unsigned>(0, getNumOutputs())) {
ShapedType resultType = getOutputShapedType(resultIdx);
SmallVector<Value> shapes;
for (unsigned dim : llvm::seq<unsigned>(0, resultType.getRank())) {
if (checkDimExpr.visit(shapeExprs[pos]))
shapes.push_back(
b.createOrFold<memref::DimOp>(loc, getOutput(resultIdx), dim));
else
shapes.push_back(allResultDimValues[pos]);
pos++;
}
reifiedReturnShapes.emplace_back(std::move(shapes));
}
return success();
}
LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {

View File

@ -21,6 +21,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Parser.h"
#include "llvm/ADT/DenseMap.h"
@ -88,6 +89,33 @@ static void printNamedStructuredOpResults(OpAsmPrinter &p,
template <typename NamedStructuredOpType>
static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op);
/// Helper function to convert a Value into an OpFoldResult, if the Value is
/// known to be a constant index value.
static SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values) {
  SmallVector<OpFoldResult> results;
  results.reserve(values.size());
  for (Value value : values) {
    APInt constValue;
    // Fold index-typed constants into their attribute form so that static
    // sizes stay static; everything else is kept as the SSA value.
    if (value.getType().isa<IndexType>() &&
        matchPattern(value, m_ConstantInt(&constValue))) {
      results.push_back(
          IntegerAttr::get(value.getType(), constValue.getSExtValue()));
      continue;
    }
    results.push_back(value);
  }
  return results;
}
/// Helper function to convert a vector of `OpFoldResult`s into a vector of
/// `Value`s.
static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
                                      ArrayRef<OpFoldResult> valueOrAttrVec) {
  SmallVector<Value> values;
  values.reserve(valueOrAttrVec.size());
  for (OpFoldResult valueOrAttr : valueOrAttrVec) {
    if (auto attr = valueOrAttr.dyn_cast<Attribute>()) {
      // Materialize static sizes as constant index ops.
      values.push_back(
          b.create<ConstantIndexOp>(loc, attr.cast<IntegerAttr>().getInt()));
    } else {
      values.push_back(valueOrAttr.get<Value>());
    }
  }
  return values;
}
/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
/// it is a Value or into `staticVec` if it is an IntegerAttr.
/// In the case of a Value, a copy of the `sentinel` value is also pushed to
@ -679,10 +707,6 @@ void InitTensorOp::build(OpBuilder &b, OperationState &result,
SmallVector<Value, 4> dynamicSizes;
SmallVector<int64_t, 4> staticSizes;
for (unsigned i = 0; i < rank; ++i) {
// staticLow and staticHigh have full information of the padding config.
// This will grow staticLow and staticHigh with 1 value. If the config is
// dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1
// value as well.
dispatchIndexOpFoldResult(sizes[i], dynamicSizes, staticSizes,
ShapedType::kDynamicSize);
}
@ -771,33 +795,6 @@ struct ReplaceStaticShapeDims : OpRewritePattern<InitTensorOp> {
return success();
}
};
/// Canonicalize a `linalg.init_tensor` -> `dim` pattern by replacing the `dim`
/// with
/// - A constant value if the size is static along the dimension.
/// - The dynamic value that defines the size of the result of
/// `linalg.init_tensor` op.
struct ReplaceDimOfInitTensorOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    // Only applies when the dim's operand is produced by an init_tensor op.
    auto initTensorOp = dimOp.memrefOrTensor().getDefiningOp<InitTensorOp>();
    if (!initTensorOp)
      return failure();
    // The queried dimension must be a constant index so the size can be
    // looked up in the init_tensor's static/dynamic size lists.
    auto dimIndex = dimOp.index().getDefiningOp<ConstantIndexOp>();
    if (!dimIndex)
      return failure();
    int64_t index = dimIndex.getValue();
    if (!initTensorOp.isDynamicSize(index)) {
      // Static size: fold the dim op to a constant.
      rewriter.replaceOpWithNewOp<ConstantIndexOp>(
          dimOp, initTensorOp.getStaticSize(index));
    } else {
      // Dynamic size: forward the SSA value that defined this size.
      rewriter.replaceOp(dimOp, initTensorOp.getDynamicSize(index));
    }
    return success();
  }
};
} // namespace
namespace {
@ -831,12 +828,20 @@ struct FoldInitTensorWithTensorReshapeOp
if (!reshapeOp.src().getDefiningOp<InitTensorOp>())
return failure();
Location loc = reshapeOp.getLoc();
SmallVector<Value, 4> resultShapeValues =
reshapeOp.getOutputShape(rewriter, loc);
SmallVector<SmallVector<Value>, 4> resultShapes;
if (failed(reshapeOp.reifyReturnTypeShapesPerResultDim(rewriter,
resultShapes)) ||
!llvm::hasSingleElement(resultShapes))
return failure();
Value initTensor = rewriter.create<InitTensorOp>(
loc, resultShapeValues, reshapeOp.getResultType().getElementType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(
reshapeOp, reshapeOp.getResultType(), initTensor);
loc, getAsOpFoldResult(resultShapes[0]),
reshapeOp.getResultType().getElementType());
if (initTensor.getType() != reshapeOp.getResultType()) {
rewriter.replaceOpWithNewOp<tensor::CastOp>(
reshapeOp, reshapeOp.getResultType(), initTensor);
} else {
rewriter.replaceOp(reshapeOp, initTensor);
}
return success();
}
};
@ -845,7 +850,20 @@ struct FoldInitTensorWithTensorReshapeOp
void InitTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<FoldInitTensorWithSubTensorOp, FoldInitTensorWithTensorReshapeOp,
ReplaceDimOfInitTensorOp, ReplaceStaticShapeDims>(context);
ReplaceStaticShapeDims>(context);
}
/// Reifies the single result's shape as one `Value` per dimension: the
/// dynamic-size operand when the dimension is dynamic, otherwise a constant
/// built from the static size.
LogicalResult InitTensorOp::reifyReturnTypeShapesPerResultDim(
    OpBuilder &builder,
    SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
  SmallVector<Value> shapes;
  int64_t rank = getType().getRank();
  shapes.reserve(rank);
  for (int64_t dim = 0; dim < rank; ++dim) {
    if (isDynamicSize(dim)) {
      shapes.push_back(getDynamicSize(dim));
      continue;
    }
    shapes.push_back(
        builder.create<ConstantIndexOp>(getLoc(), getStaticSize(dim)));
  }
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}
//===----------------------------------------------------------------------===//
@ -997,6 +1015,37 @@ PadTensorOp PadTensorOp::createPadHighOp(Type type, Value source, Value pad,
builder);
}
/// Reifies the result shape of the pad op, one `Value` per result dimension:
/// size(result, d) = dim(source, d) + lowPad[d] + highPad[d].
LogicalResult PadTensorOp::reifyReturnTypeShapesPerResultDim(
    OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
  Location loc = getLoc();
  // Mixed static/dynamic padding amounts (OpFoldResult per dimension).
  auto lowPad = getMixedLowPad();
  auto highPad = getMixedHighPad();
  SmallVector<Value> shapes;
  for (auto dim : llvm::seq<int64_t>(0, getSourceType().getRank())) {
    // Shape along each dimension is source dim + low pad + high pad.
    // Build an affine expression d0 (the source dim) plus one symbol per
    // dynamic pad amount; static pad amounts are folded into the expression
    // as constants.
    SmallVector<Value> mapOperands;
    mapOperands.push_back(b.createOrFold<memref::DimOp>(loc, source(), dim));
    AffineExpr expr = b.getAffineDimExpr(0);
    unsigned numSymbols = 0;
    auto addOpFoldResult = [&](OpFoldResult valueOrAttr) {
      if (Value v = valueOrAttr.dyn_cast<Value>()) {
        // Dynamic amount: add a fresh symbol and record its operand.
        expr = expr + b.getAffineSymbolExpr(numSymbols++);
        mapOperands.push_back(v);
        return;
      }
      // Static amount: fold the constant directly into the expression.
      int64_t staticValue =
          valueOrAttr.get<Attribute>().cast<IntegerAttr>().getInt();
      expr = expr + staticValue;
    };
    addOpFoldResult(lowPad[dim]);
    addOpFoldResult(highPad[dim]);
    // Map operand order is (dim operand, then symbols) matching (1, numSymbols).
    shapes.push_back(applyMapToValues(
        b, loc, AffineMap::get(1, numSymbols, expr), mapOperands)[0]);
  }
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}
//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//
@ -1281,7 +1330,7 @@ convertReassociationIndicesToMaps(
/// terms of shape of the `src`, when the reshape op is a collapsing
/// operation. It is the product of the shape of the collapsed dimensions of the
/// `src`.
static Value
static OpFoldResult
getCollapsedOutputDimFromInputShape(OpBuilder &builder, Location loc,
int64_t dimIndex, Value src,
ArrayRef<AffineMap> reassociationMap) {
@ -1292,7 +1341,7 @@ getCollapsedOutputDimFromInputShape(OpBuilder &builder, Location loc,
AffineExpr expr;
SmallVector<Value, 2> dynamicDims;
for (auto dim : llvm::seq(startPos, endPos + 1)) {
dynamicDims.push_back(builder.create<memref::DimOp>(loc, src, dim));
dynamicDims.push_back(builder.createOrFold<memref::DimOp>(loc, src, dim));
AffineExpr currExpr = builder.getAffineSymbolExpr(dim - startPos);
expr = (expr ? expr * currExpr : currExpr);
}
@ -1303,7 +1352,7 @@ getCollapsedOutputDimFromInputShape(OpBuilder &builder, Location loc,
/// Given the `src` of a collapsing reshape op and its reassociation maps,
/// compute the shape of the result of the reshape.
static SmallVector<Value, 4> getCollapsedOutputShapeFromInputShape(
static SmallVector<OpFoldResult, 4> getCollapsedOutputShapeFromInputShape(
OpBuilder &builder, Location loc, Value src,
ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
return llvm::to_vector<4>(llvm::map_range(
@ -1333,12 +1382,12 @@ getExpandedDimToCollapsedDimMap(ArrayRef<AffineMap> reassociation) {
/// For an expanding reshape op, compute the value for a dimension of the output
/// from the shape of the input.
static Value getExpandedOutputDimFromInputShape(
static OpFoldResult getExpandedOutputDimFromInputShape(
OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation,
llvm::DenseMap<int64_t, int64_t> &expandedDimToCollapsedDim) {
if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) {
return builder.create<ConstantIndexOp>(loc, dstStaticShape[dimIndex]);
return builder.getI64IntegerAttr(dstStaticShape[dimIndex]);
}
unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex];
unsigned startPos = reassociation[sourceDimPos]
@ -1371,7 +1420,7 @@ static Value getExpandedOutputDimFromInputShape(
/// Given the `src` of an expanding reshape op, the reassociation maps and the
/// result type, compute the shape of the result of the reshape.
static SmallVector<Value, 4> getExpandedOutputShapeFromInputShape(
static SmallVector<OpFoldResult, 4> getExpandedOutputShapeFromInputShape(
OpBuilder &builder, Location loc, Value src,
ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim =
@ -1384,9 +1433,10 @@ static SmallVector<Value, 4> getExpandedOutputShapeFromInputShape(
}));
}
SmallVector<Value, 4> mlir::linalg::getReshapeOutputShapeFromInputShape(
OpBuilder &builder, Location loc, Value src,
ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassocation) {
static SmallVector<OpFoldResult, 4>
getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src,
ArrayRef<int64_t> dstStaticShape,
ArrayRef<AffineMap> reassocation) {
return dstStaticShape.size() >
static_cast<size_t>(src.getType().cast<ShapedType>().getRank())
? getExpandedOutputShapeFromInputShape(
@ -1395,23 +1445,6 @@ SmallVector<Value, 4> mlir::linalg::getReshapeOutputShapeFromInputShape(
builder, loc, src, dstStaticShape, reassocation);
}
/// For a reshape op, compute the value of a given dimension of the output
/// (`dimIndex`) from the shape of the inputs and type of the result.
static Value getReshapeOutputDimFromInputShape(
    OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
  auto srcRank =
      static_cast<size_t>(src.getType().cast<ShapedType>().getRank());
  // More result dims than source dims means this is an expanding reshape;
  // otherwise it is a collapsing one.
  if (dstStaticShape.size() <= srcRank)
    return getCollapsedOutputDimFromInputShape(builder, loc, dimIndex, src,
                                               reassociation);
  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim =
      getExpandedDimToCollapsedDimMap(reassociation);
  return getExpandedOutputDimFromInputShape(builder, loc, dimIndex, src,
                                            dstStaticShape, reassociation,
                                            expandedDimToCollapsedDim);
}
void mlir::linalg::ReshapeOp::build(OpBuilder &b, OperationState &result,
Value src,
ArrayRef<ReassociationExprs> reassociation,
@ -1636,29 +1669,6 @@ struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
}
};
/// Canonicalize dim ops that use the output shape with dim of the input.
struct ReplaceDimOfReshapeOpResult : OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    Value dimValue = dimOp.memrefOrTensor();
    // The queried dimension must be a constant index.
    Optional<int64_t> dimIndex = dimOp.getConstantIndex();
    if (!dimIndex)
      return failure();

    // Only applies when the shaped value is produced by a tensor_reshape.
    auto reshapeOp = dimValue.getDefiningOp<TensorReshapeOp>();
    if (!reshapeOp)
      return failure();

    // Recompute the queried output dimension from the reshape's source shape
    // so the reshape result is no longer needed by the dim op.
    rewriter.replaceOp(dimOp,
                       getReshapeOutputDimFromInputShape(
                           rewriter, dimOp.getLoc(), *dimIndex, reshapeOp.src(),
                           reshapeOp.getResultType().getShape(),
                           reshapeOp.getReassociationMaps()));
    return success();
  }
};
/// Fold linalg.fill -> linalg.tensor_reshape chain.
///
/// For such op chains, we can create new linalg.fill ops with the result
@ -1684,7 +1694,18 @@ struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
void TensorReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<CollapseReshapeOps<TensorReshapeOp>, FoldFillWithTensorReshape,
FoldReshapeWithConstant, ReplaceDimOfReshapeOpResult>(context);
FoldReshapeWithConstant>(context);
}
LogicalResult TensorReshapeOp::reifyReturnTypeShapesPerResultDim(
OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
auto resultShape =
getAsValues(b, getLoc(),
getReshapeOutputShapeFromInputShape(
b, getLoc(), src(), getResultType().getShape(),
getReassociationMaps()));
reifiedReturnShapes.emplace_back(std::move(resultShape));
return success();
}
//===----------------------------------------------------------------------===//
@ -2544,50 +2565,6 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
return success();
}
};
/// Replaces memref.dim operations that use the result of a LinalgOp (on
/// tensors) with memref.dim operations that use one of the arguments. For
/// example,
///
/// %0 = linalg.matmul ins(%arg0, %arg1, ...)
/// %1 = memref.dim %0, %c0
///
/// with
///
/// %1 = memref.dim %arg0, %c0
///
/// where possible. With this the result of the `linalg.matmul` is not used in
/// dim operations. If the value produced is replaced with another value (say by
/// tiling `linalg.matmul`) will make the `linalg.matmul` truly dead instead of
/// used in a dim op that would prevent the DCE of this op.
struct ReplaceDimOfLinalgOpResult : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    Value dimValue = dimOp.memrefOrTensor();
    // The queried dimension must be a constant index.
    Optional<int64_t> dimIndex = dimOp.getConstantIndex();
    if (!dimIndex)
      return failure();
    auto linalgOp = dimValue.getDefiningOp<LinalgOp>();
    if (!linalgOp)
      return failure();

    unsigned resultIndex = dimValue.cast<OpResult>().getResultNumber();
    // Try to express the result dimension purely in terms of input shapes.
    Optional<Value> operandDimValue = linalgOp.inferResultDimFromInputShapes(
        rewriter, dimOp.getLoc(), resultIndex,
        static_cast<unsigned>(*dimIndex));
    if (!operandDimValue) {
      // Its always possible to replace using the corresponding `outs`
      // parameter.
      operandDimValue = rewriter.create<memref::DimOp>(
          dimOp.getLoc(), linalgOp.getOutput(resultIndex), *dimIndex);
    }
    rewriter.replaceOp(dimOp, *operandDimValue);
    return success();
  }
};
} // namespace
namespace {
@ -2745,7 +2722,7 @@ struct RemoveIdentityLinalgOps : public OpInterfaceRewritePattern<LinalgOp> {
void XXX::getCanonicalizationPatterns(RewritePatternSet &results, \
MLIRContext *context) { \
results.add<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp, \
RemoveIdentityLinalgOps, ReplaceDimOfLinalgOpResult>(context); \
RemoveIdentityLinalgOps>(context); \
} \
\
LogicalResult XXX::fold(ArrayRef<Attribute>, \

View File

@ -16,6 +16,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "llvm/ADT/STLExtras.h"
using namespace mlir;
@ -673,12 +674,84 @@ struct DimOfCastOp : public OpRewritePattern<DimOp> {
return success();
}
};
/// Helper method to get the `Value` that is the shape of the `result` at
/// dimension `dimIndex` from an op implementing `InferShapedTypeOpInterface`.
/// Returns a null `Value` when the shape cannot be reified.
/// TODO(ravishankarm): This is better put as a interface utility method
/// somewhere, but that would imply the interface will depend on the `tensor`
/// dialect. Ideally maybe a utility method in the `tensor` dialect.
static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result,
                                            int64_t dimIndex) {
  unsigned resultNumber = result.getResultNumber();
  auto shapedTypeOp = dyn_cast<InferShapedTypeOpInterface>(result.getOwner());
  Location loc = result.getOwner()->getLoc();
  if (!shapedTypeOp)
    return nullptr;

  // The interface exposes two methods: `reifyReturnTypeShapes` returns one
  // shape-carrying `Value` per result, while
  // `reifyReturnTypeShapesPerResultDim` returns one `Value` per (result, dim)
  // pair. The former takes precedence, so try it first and fall back to the
  // per-dim variant if it is not implemented.
  SmallVector<Value> reifiedResultShapes;
  if (succeeded(
          shapedTypeOp.reifyReturnTypeShapes(builder, reifiedResultShapes))) {
    if (reifiedResultShapes.size() <= resultNumber)
      return nullptr;
    Value resultShape = reifiedResultShapes[resultNumber];
    // The reified shape must be a ranked tensor of index values to be
    // indexable with tensor.extract.
    auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
    if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
      return nullptr;
    return builder.create<tensor::ExtractOp>(
        loc, resultShape, builder.createOrFold<ConstantIndexOp>(loc, dimIndex));
  }

  SmallVector<SmallVector<Value>> reifiedResultShapesPerDim;
  if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim(
          builder, reifiedResultShapesPerDim)))
    return nullptr;
  if (reifiedResultShapesPerDim.size() <= resultNumber ||
      reifiedResultShapesPerDim[resultNumber].size() !=
          static_cast<size_t>(result.getType().cast<ShapedType>().getRank()))
    return nullptr;
  // The per-dim entries are already `Value`s, so return the entry directly.
  // (The previous round-trip through OpFoldResult was dead code: an
  // OpFoldResult constructed from a Value can never dyn_cast to Attribute.)
  return reifiedResultShapesPerDim[resultNumber][dimIndex];
}
/// Fold dim of an operation that implements the InferShapedTypeOpInterface
struct DimOfShapedTypeOpInterface : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    // The dim's operand must be an op result (not a block argument).
    OpResult source = dimOp.memrefOrTensor().dyn_cast<OpResult>();
    if (!source)
      return failure();
    // The producing op must implement the shape-inference interface.
    auto shapeReifyingOp =
        dyn_cast<InferShapedTypeOpInterface>(source.getOwner());
    if (!shapeReifyingOp)
      return failure();
    // The queried dimension must be a constant index.
    Optional<int64_t> constIndex = dimOp.getConstantIndex();
    if (!constIndex)
      return failure();
    Value reifiedDim =
        getResultDimFromShapeInterface(rewriter, source, *constIndex);
    if (!reifiedDim)
      return failure();
    rewriter.replaceOp(dimOp, reifiedDim);
    return success();
  }
};
} // end anonymous namespace.
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<DimOfMemRefReshape, DimOfCastOp<BufferCastOp>,
DimOfCastOp<tensor::CastOp>>(context);
DimOfCastOp<tensor::CastOp>, DimOfShapedTypeOpInterface>(context);
}
// ---------------------------------------------------------------------------

View File

@ -12,7 +12,6 @@
//===----------------------------------------------------------------------===//
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/IR/BuiltinTypes.h"
using namespace mlir;

View File

@ -404,12 +404,13 @@ func @init_tensor_dynamic_dim2(%arg0 : index, %arg1 : index) -> (index, index) {
func @remove_dim_result_uses
(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
%arg2 : tensor<?x?xf32>) -> (index) {
%arg2 : tensor<?x?xf32>) -> (index, index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%0 = linalg.generic
{indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
affine_map<(d0, d1, d2) -> (d2, d1)>,
affine_map<(d0, d1, d2) -> (d0 + d1, d1)>],
affine_map<(d0, d1, d2) -> (d0 + d1, d1 - d0)>],
iterator_types = ["parallel", "parallel", "reduction"]}
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg2 : tensor<?x?xf32>) {
@ -419,9 +420,11 @@ func @remove_dim_result_uses
linalg.yield %2 : f32
} -> tensor<?x?xf32>
%3 = memref.dim %0, %c0 : tensor<?x?xf32>
return %3 : index
%4 = memref.dim %0, %c1 : tensor<?x?xf32>
return %3, %4 : index, index
}
// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
// CHECK: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
// CHECK: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (-s0 + s1)>
// CHECK: func @remove_dim_result_uses
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
@ -430,8 +433,11 @@ func @remove_dim_result_uses
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[T1:.+]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: %[[T2:.+]] = affine.apply #[[MAP]]()[%[[T0]], %[[T1]]]
// CHECK: return %[[T2]]
// CHECK: %[[T2:.+]] = affine.apply #[[MAP0]]()[%[[T0]], %[[T1]]]
// CHECK-DAG: %[[T3:.+]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[T4:.+]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: %[[T5:.+]] = affine.apply #[[MAP1]]()[%[[T3]], %[[T4]]]
// CHECK: return %[[T2]], %[[T5]]
// -----
@ -861,3 +867,38 @@ func @fold_tiled_loop_results(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
// CHECK-SAME: outs (%[[C]]:memref<192x192xf32>) {
// CHECK-NEXT: call @foo(%[[A]], %[[B]], %[[C]])
// CHECK-NEXT: linalg.yield
// -----
// Test resolution of memref.dim of a linalg.pad_tensor result in terms of its
// operand: a dim with fully static low/high padding folds to a constant
// (3 + 2 + 7 = 12 for dim 0), while dims with a dynamic low or high pad become
// an affine.apply over the pad value and the corresponding input dim.
func @dim_of_pad_op(%arg0 : tensor<2x?x?xf32>, %arg1 : index, %arg2 : index,
%arg3: f32) -> (index, index, index)
{
%c0 = constant 0 : index
%c1 = constant 1 : index
%c2 = constant 2 : index
%c3 = constant 3 : index
%c4 = constant 4 : index
%c5 = constant 5 : index
%0 = linalg.pad_tensor %arg0 low[%c3, %arg1, %c4] high[7, %c5, %arg2] {
^bb0(%arg4: index, %arg5: index, %arg6: index):
linalg.yield %arg3 : f32
} : tensor<2x?x?xf32> to tensor<?x?x?xf32>
%1 = memref.dim %0, %c0 : tensor<?x?x?xf32>
%2 = memref.dim %0, %c1 : tensor<?x?x?xf32>
%3 = memref.dim %0, %c2 : tensor<?x?x?xf32>
return %1, %2, %3 : index, index, index
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 4)>
// CHECK: func @dim_of_pad_op
// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<2x?x?xf32>
// CHECK-SAME: %[[ARG1:[A-Za-z0-9_]+]]: index
// CHECK-SAME: %[[ARG2:[A-Za-z0-9_]+]]: index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[C2:.+]] = constant 2 : index
// CHECK-DAG: %[[C12:.+]] = constant 12 : index
// CHECK: %[[IN_DIM1:.+]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK: %[[OUT_DIM1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]], %[[IN_DIM1]]]
// CHECK: %[[IN_DIM2:.+]] = memref.dim %[[ARG0]], %[[C2]]
// CHECK: %[[OUT_DIM2:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[IN_DIM2]]]
// CHECK: return %[[C12]], %[[OUT_DIM1]], %[[OUT_DIM2]]

View File

@ -81,3 +81,27 @@ func @typemismatch() -> i32 {
%0 = "test.passthrough_fold"(%c42) : (f32) -> (i32)
return %0 : i32
}
// Test that memref.dim of the results of an op implementing
// reifyReturnTypeShapesPerResultDim is rewritten in terms of the operands:
// statically-known dims fold to constants, dynamic dims are forwarded to
// memref.dim of the matching operand. The test op cross-wires its results
// (result #0 takes operand2's shape, result #1 takes operand1's shape).
// CHECK-LABEL: func @result_shape_per_dim
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
-> (index, index, index, index, index) {
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C2:.+]] = constant 2 : index
// CHECK-DAG: %[[C3:.+]] = constant 3 : index
// CHECK-DAG: %[[C5:.+]] = constant 5 : index
%c0 = constant 0 : index
%c1 = constant 1 : index
%c2 = constant 2 : index
%0:2 = "test.op_with_result_shape_per_dim_interface"(%arg0, %arg1)
: (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
%1 = memref.dim %0#0, %c0 : tensor<?x5xf32>
%2 = memref.dim %0#0, %c1 : tensor<?x5xf32>
%3 = memref.dim %0#1, %c0 : tensor<2x3x?xf32>
%4 = memref.dim %0#1, %c1 : tensor<2x3x?xf32>
%5 = memref.dim %0#1, %c2 : tensor<2x3x?xf32>
// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG_1]], %[[C0]]
// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG_0]], %[[C2]]
// CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]
return %1, %2, %3, %4, %5 : index, index, index, index, index
}

View File

@ -754,6 +754,25 @@ LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
return success();
}
/// Reifies per-dim shapes for both results of the test op. The results are
/// deliberately cross-wired: result #0 mirrors operand2's shape and result #1
/// mirrors operand1's shape, so the canonicalization has to track which
/// operand feeds which result.
LogicalResult
OpWithResultShapePerDimInterfaceOp::reifyReturnTypeShapesPerResultDim(
    OpBuilder &builder,
    llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) {
  Location loc = getLoc();
  // Materialize one memref.dim per dimension of `v`.
  auto buildShape = [&](Value v) {
    SmallVector<Value> dims;
    int rank = v.getType().cast<ShapedType>().getRank();
    dims.reserve(rank);
    for (int d = 0; d < rank; ++d)
      dims.push_back(builder.create<memref::DimOp>(loc, v, d));
    return dims;
  };
  // Swapped on purpose: result #0 <- operand2, result #1 <- operand1.
  shapes.push_back(buildShape(operand2()));
  shapes.push_back(buildShape(operand1()));
  return success();
}
//===----------------------------------------------------------------------===//
// Test SideEffect interfaces
//===----------------------------------------------------------------------===//

View File

@ -549,7 +549,8 @@ def IndexElementsAttrOp : TEST_Op<"indexElementsAttr"> {
}
def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if", [
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
DeclareOpInterfaceMethods<InferTypeOpInterface,
["inferReturnTypeComponents"]>]> {
let arguments = (ins AnyTensor, AnyTensor);
let results = (outs AnyTensor);
}
@ -560,6 +561,13 @@ def OpWithShapedTypeInferTypeInterfaceOp : TEST_Op<"op_with_shaped_type_infer_ty
let results = (outs AnyTensor);
}
def OpWithResultShapePerDimInterfaceOp : TEST_Op<"op_with_result_shape_per_dim_interface",
[DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["reifyReturnTypeShapesPerResultDim"]>]> {
let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2);
let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
}
def IsNotScalar : Constraint<CPred<"$0.getType().getRank() != 0">>;
def UpdateAttr : Pat<(I32ElementsAttrOp $attr),

View File

@ -128,11 +128,13 @@ static void reifyReturnShape(Operation *op) {
// Use permutations of 2 args as operands.
auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
SmallVector<Value, 2> shapes;
if (failed(shapedOp.reifyReturnTypeShapes(b, shapes)))
if (failed(shapedOp.reifyReturnTypeShapes(b, shapes)) ||
!llvm::hasSingleElement(shapes))
return;
for (auto it : llvm::enumerate(shapes))
for (auto it : llvm::enumerate(shapes)) {
op->emitRemark() << "value " << it.index() << ": "
<< it.value().getDefiningOp();
}
}
struct TestReturnTypeDriver