forked from OSchip/llvm-project
[mlir] Enhance InferShapedTypeOpInterface and move LinalgOps to use them.
A new `InterfaceMethod` is added to `InferShapedTypeOpInterface` that allows an operation to return the `Value`s for each dim of its results. It is intended for the case where the `Value` returned for each dim is computed using the operands and operation attributes. This interface method is for cases where the result dim of an operation can be computed independently, and it avoids the need to aggregate all dims of a result into a single shape value. This also implies that this is not suitable for cases where the result type is unranked (for which the existing interface methods is to be used). Also added is a canonicalization pattern that uses this interface and resolves the shapes of the output in terms of the shapes of the inputs. Moving Linalg ops to use this interface, so that many canonicalization patterns implemented for individual linalg ops to achieve the same result can be removed in favor of the added canonicalization pattern. Differential Revision: https://reviews.llvm.org/D97887
This commit is contained in:
parent
742f663705
commit
9b0517035f
|
@ -1087,18 +1087,20 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
|
|||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Return the position in the results of the affine map computed
|
||||
by getLoopsToShapesMap() that represents the shape of the
|
||||
result value at a dimension.
|
||||
Return the range of position in the result of the affine map
|
||||
computed by getLoopsToShapesMap() which correspond to the
|
||||
AffineExprs used to access the outputs of the operation.
|
||||
}],
|
||||
/*retTy=*/"Optional<unsigned>",
|
||||
/*methodName=*/"getResultValueDimPositionInLoopsToShapeMap",
|
||||
/*args=*/(ins "unsigned":$resultIdx, "unsigned":$dim),
|
||||
/*retTy=*/"std::pair<unsigned, unsigned>",
|
||||
/*methodName=*/"getResultsPositionInLoopsToShapeMap",
|
||||
/*args=*/(ins),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
if (resultIdx >= getNumOutputs()) return {};
|
||||
return getOperandDimPositionInLoopsToShapeMap(
|
||||
getNumInputs() + resultIdx, dim);
|
||||
return
|
||||
{*getOperandDimPositionInLoopsToShapeMap(getNumInputs(), 0),
|
||||
(*getOperandDimPositionInLoopsToShapeMap
|
||||
(getNumInputs() + getNumOutputs() - 1,
|
||||
getOutputShapedType(getNumOutputs()-1).getRank() - 1)) + 1};
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
|
@ -1226,8 +1228,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
|
|||
|
||||
/// Returns the value that expresses the shape of the output in terms of
|
||||
/// shape of the input operands where possible
|
||||
Optional<Value> inferResultDimFromInputShapes
|
||||
(OpBuilder &b, Location loc, unsigned resultIdx, unsigned im);
|
||||
LogicalResult reifyReturnTypeShapesPerResultDim(OpBuilder &b,
|
||||
SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes);
|
||||
|
||||
//========================================================================//
|
||||
// Helper functions to mutate the `operand_segment_sizes` attribute.
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Interfaces/CopyOpInterface.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "mlir/Interfaces/ViewLikeInterface.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
@ -107,13 +108,6 @@ SmallVector<AffineExpr, 4> concat(ArrayRef<AffineExpr> a,
|
|||
void getDimsOfType(Operation *op, StringRef iteratorTypeName,
|
||||
SmallVectorImpl<AffineExpr> &res);
|
||||
|
||||
/// For reshape operation, compute the shape of the output based on the result
|
||||
/// type and shape of the input.
|
||||
SmallVector<Value, 4>
|
||||
getReshapeOutputShapeFromInputShape(OpBuilder &b, Location loc, Value src,
|
||||
ArrayRef<int64_t> dstStaticShape,
|
||||
ArrayRef<AffineMap> reassociation);
|
||||
|
||||
namespace detail {
|
||||
LogicalResult verifyStructuredOpInterface(Operation *op);
|
||||
} // namespace detail
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
include "mlir/Dialect/Linalg/IR/LinalgBase.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
include "mlir/Interfaces/LoopLikeInterface.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/ViewLikeInterface.td"
|
||||
|
@ -33,7 +34,10 @@ class Linalg_Op<string mnemonic, list<OpTrait> traits = []> :
|
|||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||
}
|
||||
|
||||
def Linalg_InitTensorOp : Linalg_Op<"init_tensor", [NoSideEffect]> {
|
||||
def Linalg_InitTensorOp : Linalg_Op<"init_tensor",
|
||||
[NoSideEffect,
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["reifyReturnTypeShapesPerResultDim"]>]> {
|
||||
let summary = "operation to define a tensor of particular value";
|
||||
|
||||
let description = [{
|
||||
|
@ -126,7 +130,10 @@ def Linalg_InitTensorOp : Linalg_Op<"init_tensor", [NoSideEffect]> {
|
|||
}
|
||||
|
||||
def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
|
||||
[AttrSizedOperandSegments, NoSideEffect]> {
|
||||
[AttrSizedOperandSegments,
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["reifyReturnTypeShapesPerResultDim"]>,
|
||||
NoSideEffect]> {
|
||||
let summary = "tensor pad operation";
|
||||
let description = [{
|
||||
`linalg.pad_tensor` is an operation that pads the `source` tensor
|
||||
|
@ -348,11 +355,6 @@ class Linalg_ReshapeLikeOp<string mnemonic, list<OpTrait> traits = []> :
|
|||
a.cast<AffineMapAttr>().getValue().getResults());
|
||||
}));
|
||||
}
|
||||
SmallVector<Value, 4> getOutputShape(OpBuilder &b, Location loc) {
|
||||
return getReshapeOutputShapeFromInputShape(
|
||||
b, loc, src(), getResultType().getShape(),
|
||||
getReassociationMaps());
|
||||
}
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
$src $reassociation attr-dict `:` type($src) `into` type(results)
|
||||
|
@ -417,7 +419,10 @@ def Linalg_ReshapeOp : Linalg_ReshapeLikeOp<"reshape",
|
|||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<"tensor_reshape">,
|
||||
def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<
|
||||
"tensor_reshape",
|
||||
[DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["reifyReturnTypeShapesPerResultDim"]>]>,
|
||||
Arguments<(ins AnyTensor:$src,
|
||||
AffineMapArrayAttr:$reassociation)>,
|
||||
Results<(outs AnyTensor:$result)> {
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
include "mlir/Dialect/Linalg/IR/LinalgBase.td"
|
||||
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
|
||||
include "mlir/Interfaces/CopyOpInterface.td"
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
|
||||
// Base Tablegen class for Linalg ops.
|
||||
|
@ -25,7 +26,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
|
|||
// depending on the specific Linalg op.
|
||||
class LinalgStructuredBase_Op<string mnemonic, list<OpTrait> props>
|
||||
: Op<Linalg_Dialect, mnemonic, !listconcat(props, [
|
||||
LinalgStructuredInterface])> {
|
||||
LinalgStructuredInterface, InferShapedTypeOpInterface])> {
|
||||
code structuredOpsBaseDecls = [{
|
||||
// Return the number of induction variables in the basic block. This should
|
||||
// always be 0 for index-free linalg ops. For IndexedGeneric, this must be
|
||||
|
@ -33,6 +34,12 @@ class LinalgStructuredBase_Op<string mnemonic, list<OpTrait> props>
|
|||
unsigned getNumPayloadInductionVariables() {
|
||||
return isa<IndexedGenericOp>(this->getOperation()) ? getNumLoops() : 0;
|
||||
}
|
||||
|
||||
LogicalResult reifyReturnTypeShapesPerResultDim(OpBuilder &b,
|
||||
SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
|
||||
return cast<LinalgOp>(getOperation()).reifyReturnTypeShapesPerResultDim(b,
|
||||
reifiedReturnShapes);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
|
|
|
@ -97,21 +97,53 @@ def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
|
|||
"::mlir::DictionaryAttr":$attributes,
|
||||
"::mlir::RegionRange":$regions,
|
||||
"::mlir::SmallVectorImpl<::mlir::ShapedTypeComponents>&":
|
||||
$inferredReturnShapes)
|
||||
$inferredReturnShapes),
|
||||
/*methodBody=*/[{}],
|
||||
/*defaultImplementation=*/[{ return ::mlir::failure(); }]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{Reify the shape computation for the operation.
|
||||
|
||||
Insert operations using the given OpBuilder that computes the result
|
||||
shape.
|
||||
Insert operations using the given OpBuilder that computes the
|
||||
result shape. Only one of this method or
|
||||
`reifyReturnTypeShapesPerResultDim` needs to be overriden by the
|
||||
operation.
|
||||
}],
|
||||
/*retTy=*/"::mlir::LogicalResult",
|
||||
/*methodName=*/"reifyReturnTypeShapes",
|
||||
/*args=*/(ins "::mlir::OpBuilder&":$builder,
|
||||
"::mlir::SmallVectorImpl<::mlir::Value>&":$reifiedReturnShapes),
|
||||
"::mlir::SmallVectorImpl<Value> &":$reifiedReturnShapes),
|
||||
/*methodBody=*/[{}],
|
||||
/*defaultImplementation=*/[{ return ::mlir::failure(); }]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{Reify the shape computation for the operation.
|
||||
|
||||
Insert operations using the given OpBuilder that computes the
|
||||
result shape. The `reifiedReturnShapes` is expected to be
|
||||
populated with as many vectors as the number of results of the
|
||||
op (empty if the shape of a result value cannot be computed). If
|
||||
the returned shape for a result is not empty, its size must
|
||||
match the rank of the shaped type returned. Consequently, this
|
||||
interface can only be overridden if the return types are ranked.
|
||||
|
||||
If both this method and `reifyReturnTypeShapes` are overridden
|
||||
by the operation, `reifyReturnTypeShapes` takes precedence. This
|
||||
method is intended to be used when the shape of each result, dim
|
||||
pair can be computed independently. Using this method avoids
|
||||
adding additional instructions to aggregate individual dimension
|
||||
of a result shape into an single `Value` (and consequently
|
||||
avoids the need to extract the value from the shape on the
|
||||
client side).
|
||||
}],
|
||||
/*retTy=*/"::mlir::LogicalResult",
|
||||
/*methodName=*/"reifyReturnTypeShapesPerResultDim",
|
||||
/*args=*/(ins "::mlir::OpBuilder&":$builder,
|
||||
"::mlir::SmallVectorImpl<SmallVector<::mlir::Value>>&"
|
||||
:$reifiedReturnShapes),
|
||||
/*methodBody=*/[{}],
|
||||
/*defaultImplementation=*/[{ return ::mlir::failure(); }]
|
||||
>
|
||||
];
|
||||
}
|
||||
|
||||
|
@ -129,6 +161,7 @@ class InferTensorType<list<string> overridenMethods = []> {
|
|||
NativeOpTrait<"InferTensorType">
|
||||
];
|
||||
}
|
||||
defvar InferTensorTypeWithReify = InferTensorType<["reifyReturnTypeShapes"]>;
|
||||
defvar InferTensorTypeWithReify = InferTensorType<[
|
||||
"inferReturnTypeComponents", "reifyReturnTypeShapes"]>;
|
||||
|
||||
#endif // MLIR_INFERTYPEOPINTERFACE
|
||||
|
|
|
@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRLinalg
|
|||
LINK_LIBS PUBLIC
|
||||
MLIRAffine
|
||||
MLIRDialectUtils
|
||||
MLIRInferTypeOpInterface
|
||||
MLIRIR
|
||||
MLIRParser
|
||||
MLIRSideEffectInterfaces
|
||||
|
|
|
@ -188,7 +188,7 @@ SmallVector<Value, 4> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
|
|||
for (Value v : getShapedOperands()) {
|
||||
ShapedType t = v.getType().template cast<ShapedType>();
|
||||
for (unsigned i = 0, e = t.getRank(); i < e; ++i)
|
||||
res.push_back(b.create<memref::DimOp>(loc, v, i));
|
||||
res.push_back(b.createOrFold<memref::DimOp>(loc, v, i));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
@ -234,57 +234,58 @@ private:
|
|||
llvm::SmallSet<unsigned, 4> positions;
|
||||
};
|
||||
|
||||
Optional<Value> LinalgOp::inferResultDimFromInputShapes(OpBuilder &b,
|
||||
Location loc,
|
||||
unsigned resultIdx,
|
||||
unsigned dim) {
|
||||
LogicalResult LinalgOp::reifyReturnTypeShapesPerResultDim(
|
||||
OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
|
||||
// An example that helps understand the logic below.
|
||||
// Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
|
||||
// We want to express the shape of dim 0 of O in terms of shape of the inputs.
|
||||
// This is achieved as follows.
|
||||
// loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
|
||||
// subMapOfResultDim = (d0, d1, d2) -> (d0 + d1)
|
||||
// subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
|
||||
// shapesToLoopsMap = (d0, d2, d2, d3, d4, d5) -> (d0, d3, d2)
|
||||
// resultFromFromInputDim = subMapOfResultDim.compose(shapesToLoopMap)
|
||||
// = (d0, d1, d2, d3, d4, d5) -> (d0 + d1)
|
||||
// resultShapesFromInputShapes = subMapOfResultDim.compose(shapesToLoopMap)
|
||||
// = (d0, d1, d2, d3, d4, d5) -> (d0 + d1, d1)
|
||||
AffineMap loopsToShapesMap = getLoopsToShapesMap();
|
||||
|
||||
// Find the position in the above map that represents the shape of the
|
||||
// result:dim being inferred.
|
||||
Optional<unsigned> resultDimSubMapPos =
|
||||
getResultValueDimPositionInLoopsToShapeMap(resultIdx, dim);
|
||||
if (!resultDimSubMapPos)
|
||||
return {};
|
||||
auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap();
|
||||
|
||||
/// From loopsToShapesMap extract the submap that represents the shape of the
|
||||
/// (resultIdx, dim) needed
|
||||
AffineMap loopToResultDimShapeMap =
|
||||
loopsToShapesMap.getSubMap(*resultDimSubMapPos);
|
||||
AffineMap operandShapesToResultDimMap =
|
||||
loopToResultDimShapeMap.compose(getShapesToLoopsMap());
|
||||
/// (resultIdx, dim) needed.
|
||||
SmallVector<unsigned, 4> resultPosRange =
|
||||
llvm::to_vector<4>(llvm::seq<unsigned>(resultShapesSubMapPos.first,
|
||||
resultShapesSubMapPos.second));
|
||||
AffineMap loopToResultsShapeMap = loopsToShapesMap.getSubMap(resultPosRange);
|
||||
AffineMap resultShapesFromInputShapesMap =
|
||||
loopToResultsShapeMap.compose(getShapesToLoopsMap());
|
||||
|
||||
// Check that the result dim map does not contain the positions corresponding
|
||||
// to the outputs.
|
||||
llvm::SmallSet<unsigned, 4> outputDims;
|
||||
unsigned outputDimPosStart =
|
||||
getResultValueDimPositionInLoopsToShapeMap(0, 0).getValue();
|
||||
unsigned outputDimPosEnd =
|
||||
getResultValueDimPositionInLoopsToShapeMap(getNumOutputs() - 1,
|
||||
getOutputOpOperands()
|
||||
.back()
|
||||
.get()
|
||||
.getType()
|
||||
.cast<ShapedType>()
|
||||
.getRank() -
|
||||
1)
|
||||
.getValue();
|
||||
llvm::for_each(llvm::seq<unsigned>(outputDimPosStart, outputDimPosEnd),
|
||||
llvm::for_each(resultPosRange,
|
||||
[&outputDims](unsigned dim) { outputDims.insert(dim); });
|
||||
HasAffineDimExprVisitor checkDimExpr(outputDims);
|
||||
if (checkDimExpr.visit(operandShapesToResultDimMap.getResult(0)))
|
||||
return llvm::None;
|
||||
return applyMapToValues(b, loc, operandShapesToResultDimMap,
|
||||
createFlatListOfOperandDims(b, loc))[0];
|
||||
Location loc = getOperation()->getLoc();
|
||||
auto allResultDimValues =
|
||||
applyMapToValues(b, loc, resultShapesFromInputShapesMap,
|
||||
createFlatListOfOperandDims(b, loc));
|
||||
unsigned pos = 0;
|
||||
ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
|
||||
for (auto resultIdx : llvm::seq<unsigned>(0, getNumOutputs())) {
|
||||
ShapedType resultType = getOutputShapedType(resultIdx);
|
||||
SmallVector<Value> shapes;
|
||||
for (unsigned dim : llvm::seq<unsigned>(0, resultType.getRank())) {
|
||||
if (checkDimExpr.visit(shapeExprs[pos]))
|
||||
shapes.push_back(
|
||||
b.createOrFold<memref::DimOp>(loc, getOutput(resultIdx), dim));
|
||||
else
|
||||
shapes.push_back(allResultDimValues[pos]);
|
||||
pos++;
|
||||
}
|
||||
reifiedReturnShapes.emplace_back(std::move(shapes));
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Parser.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
|
@ -88,6 +89,33 @@ static void printNamedStructuredOpResults(OpAsmPrinter &p,
|
|||
template <typename NamedStructuredOpType>
|
||||
static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op);
|
||||
|
||||
/// Helper function to convert a Value into an OpFoldResult, if the Value is
|
||||
/// known to be a constant index value.
|
||||
static SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values) {
|
||||
return llvm::to_vector<4>(
|
||||
llvm::map_range(values, [](Value v) -> OpFoldResult {
|
||||
APInt intValue;
|
||||
if (v.getType().isa<IndexType>() &&
|
||||
matchPattern(v, m_ConstantInt(&intValue))) {
|
||||
return IntegerAttr::get(v.getType(), intValue.getSExtValue());
|
||||
}
|
||||
return v;
|
||||
}));
|
||||
}
|
||||
|
||||
/// Helper function to convert a vector of `OpFoldResult`s into a vector of
|
||||
/// `Value`s.
|
||||
static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
|
||||
ArrayRef<OpFoldResult> valueOrAttrVec) {
|
||||
return llvm::to_vector<4>(
|
||||
llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
|
||||
if (auto attr = value.dyn_cast<Attribute>())
|
||||
return b.create<ConstantIndexOp>(loc,
|
||||
attr.cast<IntegerAttr>().getInt());
|
||||
return value.get<Value>();
|
||||
}));
|
||||
}
|
||||
|
||||
/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
|
||||
/// it is a Value or into `staticVec` if it is an IntegerAttr.
|
||||
/// In the case of a Value, a copy of the `sentinel` value is also pushed to
|
||||
|
@ -679,10 +707,6 @@ void InitTensorOp::build(OpBuilder &b, OperationState &result,
|
|||
SmallVector<Value, 4> dynamicSizes;
|
||||
SmallVector<int64_t, 4> staticSizes;
|
||||
for (unsigned i = 0; i < rank; ++i) {
|
||||
// staticLow and staticHigh have full information of the padding config.
|
||||
// This will grow staticLow and staticHigh with 1 value. If the config is
|
||||
// dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1
|
||||
// value as well.
|
||||
dispatchIndexOpFoldResult(sizes[i], dynamicSizes, staticSizes,
|
||||
ShapedType::kDynamicSize);
|
||||
}
|
||||
|
@ -771,33 +795,6 @@ struct ReplaceStaticShapeDims : OpRewritePattern<InitTensorOp> {
|
|||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Canonicalize a `linalg.init_tensor` -> `dim` pattern by replacing the `dim`
|
||||
/// with
|
||||
/// - A constant value if the size is static along the dimension.
|
||||
/// - The dynamic value that defines the size of the result of
|
||||
/// `linalg.init_tensor` op.
|
||||
struct ReplaceDimOfInitTensorOp : public OpRewritePattern<memref::DimOp> {
|
||||
using OpRewritePattern<memref::DimOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(memref::DimOp dimOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto initTensorOp = dimOp.memrefOrTensor().getDefiningOp<InitTensorOp>();
|
||||
if (!initTensorOp)
|
||||
return failure();
|
||||
auto dimIndex = dimOp.index().getDefiningOp<ConstantIndexOp>();
|
||||
if (!dimIndex)
|
||||
return failure();
|
||||
int64_t index = dimIndex.getValue();
|
||||
if (!initTensorOp.isDynamicSize(index)) {
|
||||
rewriter.replaceOpWithNewOp<ConstantIndexOp>(
|
||||
dimOp, initTensorOp.getStaticSize(index));
|
||||
} else {
|
||||
rewriter.replaceOp(dimOp, initTensorOp.getDynamicSize(index));
|
||||
}
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
|
@ -831,12 +828,20 @@ struct FoldInitTensorWithTensorReshapeOp
|
|||
if (!reshapeOp.src().getDefiningOp<InitTensorOp>())
|
||||
return failure();
|
||||
Location loc = reshapeOp.getLoc();
|
||||
SmallVector<Value, 4> resultShapeValues =
|
||||
reshapeOp.getOutputShape(rewriter, loc);
|
||||
SmallVector<SmallVector<Value>, 4> resultShapes;
|
||||
if (failed(reshapeOp.reifyReturnTypeShapesPerResultDim(rewriter,
|
||||
resultShapes)) ||
|
||||
!llvm::hasSingleElement(resultShapes))
|
||||
return failure();
|
||||
Value initTensor = rewriter.create<InitTensorOp>(
|
||||
loc, resultShapeValues, reshapeOp.getResultType().getElementType());
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(
|
||||
reshapeOp, reshapeOp.getResultType(), initTensor);
|
||||
loc, getAsOpFoldResult(resultShapes[0]),
|
||||
reshapeOp.getResultType().getElementType());
|
||||
if (initTensor.getType() != reshapeOp.getResultType()) {
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(
|
||||
reshapeOp, reshapeOp.getResultType(), initTensor);
|
||||
} else {
|
||||
rewriter.replaceOp(reshapeOp, initTensor);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -845,7 +850,20 @@ struct FoldInitTensorWithTensorReshapeOp
|
|||
void InitTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {
|
||||
results.add<FoldInitTensorWithSubTensorOp, FoldInitTensorWithTensorReshapeOp,
|
||||
ReplaceDimOfInitTensorOp, ReplaceStaticShapeDims>(context);
|
||||
ReplaceStaticShapeDims>(context);
|
||||
}
|
||||
|
||||
LogicalResult InitTensorOp::reifyReturnTypeShapesPerResultDim(
|
||||
OpBuilder &builder,
|
||||
SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
|
||||
auto shapes = llvm::to_vector<4>(llvm::map_range(
|
||||
llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
|
||||
if (isDynamicSize(dim))
|
||||
return getDynamicSize(dim);
|
||||
return builder.create<ConstantIndexOp>(getLoc(), getStaticSize(dim));
|
||||
}));
|
||||
reifiedReturnShapes.emplace_back(std::move(shapes));
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -997,6 +1015,37 @@ PadTensorOp PadTensorOp::createPadHighOp(Type type, Value source, Value pad,
|
|||
builder);
|
||||
}
|
||||
|
||||
LogicalResult PadTensorOp::reifyReturnTypeShapesPerResultDim(
|
||||
OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
|
||||
Location loc = getLoc();
|
||||
auto lowPad = getMixedLowPad();
|
||||
auto highPad = getMixedHighPad();
|
||||
SmallVector<Value> shapes;
|
||||
for (auto dim : llvm::seq<int64_t>(0, getSourceType().getRank())) {
|
||||
// Shape along each dimension is source dim + low pad + high pad.
|
||||
SmallVector<Value> mapOperands;
|
||||
mapOperands.push_back(b.createOrFold<memref::DimOp>(loc, source(), dim));
|
||||
AffineExpr expr = b.getAffineDimExpr(0);
|
||||
unsigned numSymbols = 0;
|
||||
auto addOpFoldResult = [&](OpFoldResult valueOrAttr) {
|
||||
if (Value v = valueOrAttr.dyn_cast<Value>()) {
|
||||
expr = expr + b.getAffineSymbolExpr(numSymbols++);
|
||||
mapOperands.push_back(v);
|
||||
return;
|
||||
}
|
||||
int64_t staticValue =
|
||||
valueOrAttr.get<Attribute>().cast<IntegerAttr>().getInt();
|
||||
expr = expr + staticValue;
|
||||
};
|
||||
addOpFoldResult(lowPad[dim]);
|
||||
addOpFoldResult(highPad[dim]);
|
||||
shapes.push_back(applyMapToValues(
|
||||
b, loc, AffineMap::get(1, numSymbols, expr), mapOperands)[0]);
|
||||
}
|
||||
reifiedReturnShapes.emplace_back(std::move(shapes));
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ReshapeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1281,7 +1330,7 @@ convertReassociationIndicesToMaps(
|
|||
/// terms of shape of the `src`, when the reshape op is a collapsing
|
||||
/// operation. It is the product of the shape of the collapsed dimensions of the
|
||||
/// `src`.
|
||||
static Value
|
||||
static OpFoldResult
|
||||
getCollapsedOutputDimFromInputShape(OpBuilder &builder, Location loc,
|
||||
int64_t dimIndex, Value src,
|
||||
ArrayRef<AffineMap> reassociationMap) {
|
||||
|
@ -1292,7 +1341,7 @@ getCollapsedOutputDimFromInputShape(OpBuilder &builder, Location loc,
|
|||
AffineExpr expr;
|
||||
SmallVector<Value, 2> dynamicDims;
|
||||
for (auto dim : llvm::seq(startPos, endPos + 1)) {
|
||||
dynamicDims.push_back(builder.create<memref::DimOp>(loc, src, dim));
|
||||
dynamicDims.push_back(builder.createOrFold<memref::DimOp>(loc, src, dim));
|
||||
AffineExpr currExpr = builder.getAffineSymbolExpr(dim - startPos);
|
||||
expr = (expr ? expr * currExpr : currExpr);
|
||||
}
|
||||
|
@ -1303,7 +1352,7 @@ getCollapsedOutputDimFromInputShape(OpBuilder &builder, Location loc,
|
|||
|
||||
/// Given the `src` of a collapsing reshape op and its reassociation maps,
|
||||
/// compute the shape of the result of the reshape.
|
||||
static SmallVector<Value, 4> getCollapsedOutputShapeFromInputShape(
|
||||
static SmallVector<OpFoldResult, 4> getCollapsedOutputShapeFromInputShape(
|
||||
OpBuilder &builder, Location loc, Value src,
|
||||
ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
|
||||
return llvm::to_vector<4>(llvm::map_range(
|
||||
|
@ -1333,12 +1382,12 @@ getExpandedDimToCollapsedDimMap(ArrayRef<AffineMap> reassociation) {
|
|||
|
||||
/// For an expanding reshape op, compute the value for a dimension of the output
|
||||
/// from the shape of the input.
|
||||
static Value getExpandedOutputDimFromInputShape(
|
||||
static OpFoldResult getExpandedOutputDimFromInputShape(
|
||||
OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
|
||||
ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation,
|
||||
llvm::DenseMap<int64_t, int64_t> &expandedDimToCollapsedDim) {
|
||||
if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) {
|
||||
return builder.create<ConstantIndexOp>(loc, dstStaticShape[dimIndex]);
|
||||
return builder.getI64IntegerAttr(dstStaticShape[dimIndex]);
|
||||
}
|
||||
unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex];
|
||||
unsigned startPos = reassociation[sourceDimPos]
|
||||
|
@ -1371,7 +1420,7 @@ static Value getExpandedOutputDimFromInputShape(
|
|||
|
||||
/// Given the `src` of an expanding reshape op, the reassociation maps and the
|
||||
/// result type, compute the shape of the result of the reshape.
|
||||
static SmallVector<Value, 4> getExpandedOutputShapeFromInputShape(
|
||||
static SmallVector<OpFoldResult, 4> getExpandedOutputShapeFromInputShape(
|
||||
OpBuilder &builder, Location loc, Value src,
|
||||
ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
|
||||
llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim =
|
||||
|
@ -1384,9 +1433,10 @@ static SmallVector<Value, 4> getExpandedOutputShapeFromInputShape(
|
|||
}));
|
||||
}
|
||||
|
||||
SmallVector<Value, 4> mlir::linalg::getReshapeOutputShapeFromInputShape(
|
||||
OpBuilder &builder, Location loc, Value src,
|
||||
ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassocation) {
|
||||
static SmallVector<OpFoldResult, 4>
|
||||
getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src,
|
||||
ArrayRef<int64_t> dstStaticShape,
|
||||
ArrayRef<AffineMap> reassocation) {
|
||||
return dstStaticShape.size() >
|
||||
static_cast<size_t>(src.getType().cast<ShapedType>().getRank())
|
||||
? getExpandedOutputShapeFromInputShape(
|
||||
|
@ -1395,23 +1445,6 @@ SmallVector<Value, 4> mlir::linalg::getReshapeOutputShapeFromInputShape(
|
|||
builder, loc, src, dstStaticShape, reassocation);
|
||||
}
|
||||
|
||||
/// For a reshape op, compute the value of a given dimension of the output
|
||||
/// (`dimIndex`) from the shape of the inputs and type of the result.
|
||||
static Value getReshapeOutputDimFromInputShape(
|
||||
OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
|
||||
ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
|
||||
if (dstStaticShape.size() >
|
||||
static_cast<size_t>(src.getType().cast<ShapedType>().getRank())) {
|
||||
llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim =
|
||||
getExpandedDimToCollapsedDimMap(reassociation);
|
||||
return getExpandedOutputDimFromInputShape(builder, loc, dimIndex, src,
|
||||
dstStaticShape, reassociation,
|
||||
expandedDimToCollapsedDim);
|
||||
}
|
||||
return getCollapsedOutputDimFromInputShape(builder, loc, dimIndex, src,
|
||||
reassociation);
|
||||
}
|
||||
|
||||
void mlir::linalg::ReshapeOp::build(OpBuilder &b, OperationState &result,
|
||||
Value src,
|
||||
ArrayRef<ReassociationExprs> reassociation,
|
||||
|
@ -1636,29 +1669,6 @@ struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
|
|||
}
|
||||
};
|
||||
|
||||
/// Canonicalize dim ops that use the output shape with dim of the input.
|
||||
struct ReplaceDimOfReshapeOpResult : OpRewritePattern<memref::DimOp> {
|
||||
using OpRewritePattern<memref::DimOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(memref::DimOp dimOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value dimValue = dimOp.memrefOrTensor();
|
||||
Optional<int64_t> dimIndex = dimOp.getConstantIndex();
|
||||
if (!dimIndex)
|
||||
return failure();
|
||||
|
||||
auto reshapeOp = dimValue.getDefiningOp<TensorReshapeOp>();
|
||||
if (!reshapeOp)
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOp(dimOp,
|
||||
getReshapeOutputDimFromInputShape(
|
||||
rewriter, dimOp.getLoc(), *dimIndex, reshapeOp.src(),
|
||||
reshapeOp.getResultType().getShape(),
|
||||
reshapeOp.getReassociationMaps()));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Fold linalg.fill -> linalg.tensor_reshape chain.
|
||||
///
|
||||
/// For such op chains, we can create new linalg.fill ops with the result
|
||||
|
@ -1684,7 +1694,18 @@ struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
|
|||
void TensorReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {
|
||||
results.add<CollapseReshapeOps<TensorReshapeOp>, FoldFillWithTensorReshape,
|
||||
FoldReshapeWithConstant, ReplaceDimOfReshapeOpResult>(context);
|
||||
FoldReshapeWithConstant>(context);
|
||||
}
|
||||
|
||||
LogicalResult TensorReshapeOp::reifyReturnTypeShapesPerResultDim(
|
||||
OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
|
||||
auto resultShape =
|
||||
getAsValues(b, getLoc(),
|
||||
getReshapeOutputShapeFromInputShape(
|
||||
b, getLoc(), src(), getResultType().getShape(),
|
||||
getReassociationMaps()));
|
||||
reifiedReturnShapes.emplace_back(std::move(resultShape));
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -2544,50 +2565,6 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
|
|||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Replaces memref.dim operations that use the result of a LinalgOp (on
|
||||
/// tensors) with memref.dim operations that use one of the arguments. For
|
||||
/// example,
|
||||
///
|
||||
/// %0 = linalg.matmul ins(%arg0, %arg1, ...)
|
||||
/// %1 = memref.dim %0, %c0
|
||||
///
|
||||
/// with
|
||||
///
|
||||
/// %1 = memref.dim %arg0, %c0
|
||||
///
|
||||
/// where possible. With this the result of the `linalg.matmul` is not used in
|
||||
/// dim operations. If the value produced is replaced with another value (say by
|
||||
/// tiling `linalg.matmul`) will make the `linalg.matmul` truly dead instead of
|
||||
/// used in a dim op that would prevent the DCE of this op.
|
||||
struct ReplaceDimOfLinalgOpResult : public OpRewritePattern<memref::DimOp> {
|
||||
using OpRewritePattern<memref::DimOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(memref::DimOp dimOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value dimValue = dimOp.memrefOrTensor();
|
||||
Optional<int64_t> dimIndex = dimOp.getConstantIndex();
|
||||
if (!dimIndex)
|
||||
return failure();
|
||||
auto linalgOp = dimValue.getDefiningOp<LinalgOp>();
|
||||
if (!linalgOp)
|
||||
return failure();
|
||||
|
||||
unsigned resultIndex = dimValue.cast<OpResult>().getResultNumber();
|
||||
Optional<Value> operandDimValue = linalgOp.inferResultDimFromInputShapes(
|
||||
rewriter, dimOp.getLoc(), resultIndex,
|
||||
static_cast<unsigned>(*dimIndex));
|
||||
if (!operandDimValue) {
|
||||
// Its always possible to replace using the corresponding `outs`
|
||||
// parameter.
|
||||
operandDimValue = rewriter.create<memref::DimOp>(
|
||||
dimOp.getLoc(), linalgOp.getOutput(resultIndex), *dimIndex);
|
||||
}
|
||||
rewriter.replaceOp(dimOp, *operandDimValue);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
|
@ -2745,7 +2722,7 @@ struct RemoveIdentityLinalgOps : public OpInterfaceRewritePattern<LinalgOp> {
|
|||
void XXX::getCanonicalizationPatterns(RewritePatternSet &results, \
|
||||
MLIRContext *context) { \
|
||||
results.add<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp, \
|
||||
RemoveIdentityLinalgOps, ReplaceDimOfLinalgOpResult>(context); \
|
||||
RemoveIdentityLinalgOps>(context); \
|
||||
} \
|
||||
\
|
||||
LogicalResult XXX::fold(ArrayRef<Attribute>, \
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -673,12 +674,84 @@ struct DimOfCastOp : public OpRewritePattern<DimOp> {
|
|||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Helper method to get the `Value` that is the shape of the `resultIdx`-th
|
||||
/// result at dimension `dimIndex` from the `ShapedTypeOpInterface`.
|
||||
/// TODO(ravishankarm): This is better put as a interface utility method
|
||||
/// somewhere, but that would imply the interface will depend on the `tensor`
|
||||
/// dialect. Ideally maybe a utility method in the `tensor` dialect.
|
||||
static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result,
|
||||
int64_t dimIndex) {
|
||||
unsigned resultNumber = result.getResultNumber();
|
||||
auto shapedTypeOp = dyn_cast<InferShapedTypeOpInterface>(result.getOwner());
|
||||
Location loc = result.getOwner()->getLoc();
|
||||
if (!shapedTypeOp)
|
||||
return nullptr;
|
||||
|
||||
// The interface exposes two methods, one that returns the shape of all the
|
||||
// results as `Value` and other that returns the shape as a list of
|
||||
// `SmallVector<Value>`. The former takes precedence over the latter. So first
|
||||
// check if the op implements the first interface method or the second, and
|
||||
// get the value to use appropriately.
|
||||
SmallVector<Value> reifiedResultShapes;
|
||||
if (succeeded(
|
||||
shapedTypeOp.reifyReturnTypeShapes(builder, reifiedResultShapes))) {
|
||||
if (reifiedResultShapes.size() <= resultNumber)
|
||||
return nullptr;
|
||||
Value resultShape = reifiedResultShapes[resultNumber];
|
||||
auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
|
||||
if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
|
||||
return nullptr;
|
||||
return builder.create<tensor::ExtractOp>(
|
||||
loc, resultShape, builder.createOrFold<ConstantIndexOp>(loc, dimIndex));
|
||||
}
|
||||
|
||||
SmallVector<SmallVector<Value>> reifiedResultShapesPerDim;
|
||||
if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim(
|
||||
builder, reifiedResultShapesPerDim)))
|
||||
return nullptr;
|
||||
if (reifiedResultShapesPerDim.size() <= resultNumber ||
|
||||
reifiedResultShapesPerDim[resultNumber].size() !=
|
||||
static_cast<size_t>(result.getType().cast<ShapedType>().getRank()))
|
||||
return nullptr;
|
||||
OpFoldResult valueOrAttr = reifiedResultShapesPerDim[resultNumber][dimIndex];
|
||||
if (auto attr = valueOrAttr.dyn_cast<Attribute>())
|
||||
return builder.createOrFold<ConstantIndexOp>(
|
||||
loc, attr.cast<IntegerAttr>().getInt());
|
||||
return valueOrAttr.get<Value>();
|
||||
}
|
||||
|
||||
/// Fold dim of an operation that implements the InferShapedTypeOpInterface
|
||||
struct DimOfShapedTypeOpInterface : public OpRewritePattern<DimOp> {
|
||||
using OpRewritePattern<DimOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(DimOp dimOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
OpResult dimValue = dimOp.memrefOrTensor().dyn_cast<OpResult>();
|
||||
if (!dimValue)
|
||||
return failure();
|
||||
auto shapedTypeOp =
|
||||
dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner());
|
||||
if (!shapedTypeOp)
|
||||
return failure();
|
||||
|
||||
Optional<int64_t> dimIndex = dimOp.getConstantIndex();
|
||||
if (!dimIndex)
|
||||
return failure();
|
||||
Value replacement =
|
||||
getResultDimFromShapeInterface(rewriter, dimValue, *dimIndex);
|
||||
if (!replacement)
|
||||
return failure();
|
||||
rewriter.replaceOp(dimOp, replacement);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace.
|
||||
|
||||
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {
|
||||
results.add<DimOfMemRefReshape, DimOfCastOp<BufferCastOp>,
|
||||
DimOfCastOp<tensor::CastOp>>(context);
|
||||
DimOfCastOp<tensor::CastOp>, DimOfShapedTypeOpInterface>(context);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
|
|
|
@ -12,7 +12,6 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
|
|
@ -404,12 +404,13 @@ func @init_tensor_dynamic_dim2(%arg0 : index, %arg1 : index) -> (index, index) {
|
|||
|
||||
func @remove_dim_result_uses
|
||||
(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
|
||||
%arg2 : tensor<?x?xf32>) -> (index) {
|
||||
%arg2 : tensor<?x?xf32>) -> (index, index) {
|
||||
%c0 = constant 0 : index
|
||||
%c1 = constant 1 : index
|
||||
%0 = linalg.generic
|
||||
{indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
|
||||
affine_map<(d0, d1, d2) -> (d2, d1)>,
|
||||
affine_map<(d0, d1, d2) -> (d0 + d1, d1)>],
|
||||
affine_map<(d0, d1, d2) -> (d0 + d1, d1 - d0)>],
|
||||
iterator_types = ["parallel", "parallel", "reduction"]}
|
||||
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
|
||||
outs(%arg2 : tensor<?x?xf32>) {
|
||||
|
@ -419,9 +420,11 @@ func @remove_dim_result_uses
|
|||
linalg.yield %2 : f32
|
||||
} -> tensor<?x?xf32>
|
||||
%3 = memref.dim %0, %c0 : tensor<?x?xf32>
|
||||
return %3 : index
|
||||
%4 = memref.dim %0, %c1 : tensor<?x?xf32>
|
||||
return %3, %4 : index, index
|
||||
}
|
||||
// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
|
||||
// CHECK: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
|
||||
// CHECK: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (-s0 + s1)>
|
||||
// CHECK: func @remove_dim_result_uses
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
|
||||
|
@ -430,8 +433,11 @@ func @remove_dim_result_uses
|
|||
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
|
||||
// CHECK-DAG: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]]
|
||||
// CHECK-DAG: %[[T1:.+]] = memref.dim %[[ARG1]], %[[C1]]
|
||||
// CHECK: %[[T2:.+]] = affine.apply #[[MAP]]()[%[[T0]], %[[T1]]]
|
||||
// CHECK: return %[[T2]]
|
||||
// CHECK: %[[T2:.+]] = affine.apply #[[MAP0]]()[%[[T0]], %[[T1]]]
|
||||
// CHECK-DAG: %[[T3:.+]] = memref.dim %[[ARG0]], %[[C0]]
|
||||
// CHECK-DAG: %[[T4:.+]] = memref.dim %[[ARG1]], %[[C1]]
|
||||
// CHECK: %[[T5:.+]] = affine.apply #[[MAP1]]()[%[[T3]], %[[T4]]]
|
||||
// CHECK: return %[[T2]], %[[T5]]
|
||||
|
||||
// -----
|
||||
|
||||
|
@ -861,3 +867,38 @@ func @fold_tiled_loop_results(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
|
|||
// CHECK-SAME: outs (%[[C]]:memref<192x192xf32>) {
|
||||
// CHECK-NEXT: call @foo(%[[A]], %[[B]], %[[C]])
|
||||
// CHECK-NEXT: linalg.yield
|
||||
|
||||
// -----
|
||||
|
||||
func @dim_of_pad_op(%arg0 : tensor<2x?x?xf32>, %arg1 : index, %arg2 : index,
|
||||
%arg3: f32) -> (index, index, index)
|
||||
{
|
||||
%c0 = constant 0 : index
|
||||
%c1 = constant 1 : index
|
||||
%c2 = constant 2 : index
|
||||
%c3 = constant 3 : index
|
||||
%c4 = constant 4 : index
|
||||
%c5 = constant 5 : index
|
||||
%0 = linalg.pad_tensor %arg0 low[%c3, %arg1, %c4] high[7, %c5, %arg2] {
|
||||
^bb0(%arg4: index, %arg5: index, %arg6: index):
|
||||
linalg.yield %arg3 : f32
|
||||
} : tensor<2x?x?xf32> to tensor<?x?x?xf32>
|
||||
%1 = memref.dim %0, %c0 : tensor<?x?x?xf32>
|
||||
%2 = memref.dim %0, %c1 : tensor<?x?x?xf32>
|
||||
%3 = memref.dim %0, %c2 : tensor<?x?x?xf32>
|
||||
return %1, %2, %3 : index, index, index
|
||||
}
|
||||
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
|
||||
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 4)>
|
||||
// CHECK: func @dim_of_pad_op
|
||||
// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<2x?x?xf32>
|
||||
// CHECK-SAME: %[[ARG1:[A-Za-z0-9_]+]]: index
|
||||
// CHECK-SAME: %[[ARG2:[A-Za-z0-9_]+]]: index
|
||||
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
|
||||
// CHECK-DAG: %[[C2:.+]] = constant 2 : index
|
||||
// CHECK-DAG: %[[C12:.+]] = constant 12 : index
|
||||
// CHECK: %[[IN_DIM1:.+]] = memref.dim %[[ARG0]], %[[C1]]
|
||||
// CHECK: %[[OUT_DIM1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]], %[[IN_DIM1]]]
|
||||
// CHECK: %[[IN_DIM2:.+]] = memref.dim %[[ARG0]], %[[C2]]
|
||||
// CHECK: %[[OUT_DIM2:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[IN_DIM2]]]
|
||||
// CHECK: return %[[C12]], %[[OUT_DIM1]], %[[OUT_DIM2]]
|
||||
|
|
|
@ -81,3 +81,27 @@ func @typemismatch() -> i32 {
|
|||
%0 = "test.passthrough_fold"(%c42) : (f32) -> (i32)
|
||||
return %0 : i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @result_shape_per_dim
|
||||
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
|
||||
func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
|
||||
-> (index, index, index, index, index) {
|
||||
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
|
||||
// CHECK-DAG: %[[C2:.+]] = constant 2 : index
|
||||
// CHECK-DAG: %[[C3:.+]] = constant 3 : index
|
||||
// CHECK-DAG: %[[C5:.+]] = constant 5 : index
|
||||
%c0 = constant 0 : index
|
||||
%c1 = constant 1 : index
|
||||
%c2 = constant 2 : index
|
||||
%0:2 = "test.op_with_result_shape_per_dim_interface"(%arg0, %arg1)
|
||||
: (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
|
||||
%1 = memref.dim %0#0, %c0 : tensor<?x5xf32>
|
||||
%2 = memref.dim %0#0, %c1 : tensor<?x5xf32>
|
||||
%3 = memref.dim %0#1, %c0 : tensor<2x3x?xf32>
|
||||
%4 = memref.dim %0#1, %c1 : tensor<2x3x?xf32>
|
||||
%5 = memref.dim %0#1, %c2 : tensor<2x3x?xf32>
|
||||
// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG_1]], %[[C0]]
|
||||
// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG_0]], %[[C2]]
|
||||
// CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]
|
||||
return %1, %2, %3, %4, %5 : index, index, index, index, index
|
||||
}
|
|
@ -754,6 +754,25 @@ LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
OpWithResultShapePerDimInterfaceOp ::reifyReturnTypeShapesPerResultDim(
|
||||
OpBuilder &builder,
|
||||
llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) {
|
||||
SmallVector<Value> operand1Shape, operand2Shape;
|
||||
Location loc = getLoc();
|
||||
for (auto i :
|
||||
llvm::seq<int>(0, operand1().getType().cast<ShapedType>().getRank())) {
|
||||
operand1Shape.push_back(builder.create<memref::DimOp>(loc, operand1(), i));
|
||||
}
|
||||
for (auto i :
|
||||
llvm::seq<int>(0, operand2().getType().cast<ShapedType>().getRank())) {
|
||||
operand2Shape.push_back(builder.create<memref::DimOp>(loc, operand2(), i));
|
||||
}
|
||||
shapes.emplace_back(std::move(operand2Shape));
|
||||
shapes.emplace_back(std::move(operand1Shape));
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test SideEffect interfaces
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -549,7 +549,8 @@ def IndexElementsAttrOp : TEST_Op<"indexElementsAttr"> {
|
|||
}
|
||||
|
||||
def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if", [
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface,
|
||||
["inferReturnTypeComponents"]>]> {
|
||||
let arguments = (ins AnyTensor, AnyTensor);
|
||||
let results = (outs AnyTensor);
|
||||
}
|
||||
|
@ -560,6 +561,13 @@ def OpWithShapedTypeInferTypeInterfaceOp : TEST_Op<"op_with_shaped_type_infer_ty
|
|||
let results = (outs AnyTensor);
|
||||
}
|
||||
|
||||
def OpWithResultShapePerDimInterfaceOp : TEST_Op<"op_with_result_shape_per_dim_interface",
|
||||
[DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["reifyReturnTypeShapesPerResultDim"]>]> {
|
||||
let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2);
|
||||
let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
|
||||
}
|
||||
|
||||
def IsNotScalar : Constraint<CPred<"$0.getType().getRank() != 0">>;
|
||||
|
||||
def UpdateAttr : Pat<(I32ElementsAttrOp $attr),
|
||||
|
|
|
@ -128,11 +128,13 @@ static void reifyReturnShape(Operation *op) {
|
|||
// Use permutations of 2 args as operands.
|
||||
auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
|
||||
SmallVector<Value, 2> shapes;
|
||||
if (failed(shapedOp.reifyReturnTypeShapes(b, shapes)))
|
||||
if (failed(shapedOp.reifyReturnTypeShapes(b, shapes)) ||
|
||||
!llvm::hasSingleElement(shapes))
|
||||
return;
|
||||
for (auto it : llvm::enumerate(shapes))
|
||||
for (auto it : llvm::enumerate(shapes)) {
|
||||
op->emitRemark() << "value " << it.index() << ": "
|
||||
<< it.value().getDefiningOp();
|
||||
}
|
||||
}
|
||||
|
||||
struct TestReturnTypeDriver
|
||||
|
|
Loading…
Reference in New Issue