[mlir] Split linalg reshape ops into expand/collapse.

Differential Revision: https://reviews.llvm.org/D103548
This commit is contained in:
Alexander Belyaev 2021-06-03 11:33:56 +02:00
parent 1de1887f5f
commit 485c21be8a
25 changed files with 826 additions and 622 deletions

View File

@ -330,7 +330,12 @@ class Linalg_ReshapeLikeOp<string mnemonic, list<OpTrait> traits = []> :
// be either a contracting or expanding reshape.
OpBuilder<(ins "Type":$resultType, "Value":$src,
"ArrayRef<ReassociationIndices>":$reassociation,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
[{
build($_builder, $_state, resultType, src, attrs);
$_state.addAttribute("reassociation",
getReassociationIndicesAttribute($_builder, reassociation));
}]>,
OpBuilder<(ins "Type":$resultType, "Value":$src,
"ArrayRef<ReassociationExprs>":$reassociation,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
@ -355,21 +360,33 @@ class Linalg_ReshapeLikeOp<string mnemonic, list<OpTrait> traits = []> :
return reassociationIndices;
};
}];
let parser = [{ return ::parseReshapeLikeOp(parser, result); }];
}
def IndexListArrayAttr :
TypedArrayAttrBase<I64ArrayAttr, "Array of 64-bit integer array attributes">;
def Linalg_ReshapeOp : Linalg_ReshapeLikeOp<"reshape",
class Linalg_ReshapeOp<string mnemonic> : Linalg_ReshapeLikeOp<mnemonic,
[DeclareOpInterfaceMethods<ViewLikeOpInterface>]>,
Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
Results<(outs AnyStridedMemRef:$result)> {
let summary = "linalg.reshape produces a new view into the operand view";
let extraClassDeclaration = commonExtraClassDeclaration # [{
MemRefType getSrcType() { return src().getType().cast<MemRefType>(); }
MemRefType getResultType() { return result().getType().cast<MemRefType>(); }
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
let printer = [{ return ::print(p, *this); }];
}
def Linalg_ExpandShapeOp : Linalg_ReshapeOp<"expand_shape"> {
let summary = "operation to produce a memref with a higher rank.";
let description = [{
The `linalg.reshape` op produces a new view whose sizes are a reassociation
of the original `view`. Depending on whether or not the reassociated
MemRefType is contiguous, the resulting memref may require explicit alloc
and copies.
The `linalg.expand_shape` op produces a new view with a higher rank whose
sizes are a reassociation of the original `view`. Depending on whether or
not the reassociated MemRefType is contiguous, the resulting memref may
require explicit alloc and copies.
A reassociation is defined as a continuous grouping of dimensions and is
represented with an array of I64ArrayAttr attribute.
@ -381,85 +398,67 @@ def Linalg_ReshapeOp : Linalg_ReshapeLikeOp<"reshape",
All other cases are undefined behavior and a reshape op may not lower to
LLVM if it cannot be proven statically that it does not require alloc+copy.
A reshape may either collapse or expand dimensions, depending on the
relationship between source and target memref ranks. The verification rule
is that the reassociation maps are applied to the memref with the larger
rank to obtain the memref with the smaller rank. In the case of a dimension
expansion, the reassociation maps can be interpreted as inverse maps.
The operand memref type can be zero-ranked if the result memref type is
statically shaped with all dimensions being unit extent. In such cases the
reassociation map is empty.
The result memref type of a reshape when dimensions are collapsed
(operand memref type when dimensions are expanded) can be
zero-ranked if the operand memref type (or the result memref type
when dimensions are expanded) is statically shaped with all
dimensions being unit extent. In such cases the reassociation map
is empty.
The verification rule is that the reassociation maps are applied to the
result memref with the larger rank to obtain the operand memref with the
smaller rank.
Example:
```mlir
// Dimension expansion i -> (i', j') and (k) -> (k')
%1 = linalg.expand_shape %0 [[0, 1], [2]] :
memref<?x?xf32, stride_spec> into memref<?x?x?xf32, stride_spec_2>
```
}];
}
def Linalg_CollapseShapeOp : Linalg_ReshapeOp<"collapse_shape"> {
let summary = "operation to produce a memref with a smaller rank.";
let description = [{
The `linalg.collapse_shape` op produces a new view with a smaller rank
whose sizes are a reassociation of the original `view`. Depending on
whether or not the reassociated MemRefType is contiguous, the resulting
memref may require explicit alloc and copies.
A reassociation is defined as a continuous grouping of dimensions and is
represented with an array of I64ArrayAttr attribute.
For now, it is assumed that either:
1. a reassociation produces and consumes contiguous MemRefType or,
2. the reshape op will be folded into its consumers (by changing the shape
of the computations).
All other cases are undefined behavior and a reshape op may not lower to
LLVM if it cannot be proven statically that it does not require alloc+copy.
The result memref type of a reshape can be zero-ranked if the operand
memref type is statically shaped with all dimensions being unit extent. In
such cases the reassociation map is empty.
The verification rule is that the reassociation maps are applied to the
operand memref with the larger rank to obtain the result memref with the
smaller rank.
Examples:
```mlir
// Dimension collapse (i, j) -> i' and k -> k'
%1 = linalg.reshape %0 [[0, 1], [2]] :
%1 = linalg.collapse_shape %0 [[0, 1], [2]] :
memref<?x?x?xf32, stride_spec> into memref<?x?xf32, stride_spec_2>
```
```mlir
// Dimension expansion i -> (i', j') and (k) -> (k')
%1 = linalg.reshape %0 [[0, 1], [2]] :
memref<?x?xf32, stride_spec> into memref<?x?x?xf32, stride_spec_2>
```
}];
let extraClassDeclaration = commonExtraClassDeclaration # [{
MemRefType getSrcType() { return src().getType().cast<MemRefType>(); }
MemRefType getResultType() { return result().getType().cast<MemRefType>(); }
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parseReshapeLikeOp(parser, result); }];
}
def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<
"tensor_reshape",
class Linalg_TensorReshapeOp<string mnemonic> : Linalg_ReshapeLikeOp<
mnemonic,
[DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["reifyReturnTypeShapesPerResultDim"]>]>,
Arguments<(ins AnyTensor:$src,
IndexListArrayAttr:$reassociation)>,
Results<(outs AnyTensor:$result)> {
let summary = "linalg.tensor_reshape produces a new reshaped tensor.";
let description = [{
The `linalg.reshape` op produces a new tensor whose sizes are a
reassociation of the original `src`.
A reassociation is defined as a continuous grouping of dimensions and is
represented with an array of I64ArrayAttr attribute.
A reshape may either collapse or expand dimensions, depending on the
relationship between source and target tensor ranks. The verification rule
is that the reassociation maps are applied to the tensor with the larger
rank to obtain the tensor with the smaller rank. In the case of a dimension
expansion, the reassociation maps can be interpreted as inverse maps.
The result tensor type of a reshape when dimensions are collapsed
(operand tensor type when dimensions are expanded) can be
zero-ranked if the operand tensor type (or the result tensor type
when dimensions are expanded) is statically shaped with all
dimensions being unit extent. In such cases the reassociation map
is empty.
Examples:
```mlir
// Dimension collapse (i, j) -> i' and k -> k'
%b = linalg.tensor_reshape %a [[0, 1], [2]]
: tensor<?x?x?xf32> into tensor<?x?xf32>
```
```mlir
// Dimension expansion i -> (i', j') and (k) -> (k')
%b = linalg.tensor_reshape %a [[0, 1], [2]]
: tensor<?x?xf32> into tensor<?x?x?xf32>
```
}];
let extraClassDeclaration = commonExtraClassDeclaration # [{
RankedTensorType getSrcType() {
return src().getType().cast<RankedTensorType>();
@ -474,6 +473,60 @@ def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<
let parser = [{ return ::parseReshapeLikeOp(parser, result); }];
}
def Linalg_TensorExpandShapeOp : Linalg_TensorReshapeOp<"tensor_expand_shape"> {
let summary = "operation to produce a tensor with a higher rank";
let description = [{
The `linalg.tensor_expand_shape` op produces a new tensor with a higher
rank whose sizes are a reassociation of the original `src`.
A reassociation is defined as a continuous grouping of dimensions and is
represented with an array of I64ArrayAttr attribute.
The verification rule is that the reassociation maps are applied to the
result tensor with the higher rank to obtain the operand tensor with the
smaller rank.
The operand tensor type of a reshape can be zero-ranked if the result
tensor type is statically shaped with all dimensions being unit extent. In
such cases the reassociation map is empty.
Examples:
```mlir
// Dimension expansion i -> (i', j') and (k) -> (k')
%b = linalg.tensor_expand_shape %a [[0, 1], [2]]
: tensor<?x?xf32> into tensor<?x?x?xf32>
```
}];
}
def Linalg_TensorCollapseShapeOp : Linalg_TensorReshapeOp<"tensor_collapse_shape"> {
let summary = "operation to produce a tensor with a smaller rank";
let description = [{
The `linalg.tensor_collapse_shape` op produces a new tensor with a smaller
rank whose sizes are a reassociation of the original `src`.
A reassociation is defined as a continuous grouping of dimensions and is
represented with an array of I64ArrayAttr attribute.
The verification rule is that the reassociation maps are applied to the
operand tensor with the higher rank to obtain the result tensor with the
smaller rank.
The result tensor type of a reshape can be zero-ranked if the operand
tensor type is statically shaped with all dimensions being unit extent. In
such cases the reassociation map is empty.
Examples:
```mlir
// Dimension collapse (i, j) -> i' and k -> k'
%b = linalg.tensor_collapse_shape %a [[0, 1], [2]]
: tensor<?x?x?xf32> into tensor<?x?xf32>
```
}];
}
def Linalg_YieldOp : Linalg_Op<"yield", [NoSideEffect, ReturnLike, Terminator]>,
Arguments<(ins Variadic<AnyType>:$values)> {
let summary = "Linalg yield operation";

View File

@ -95,9 +95,11 @@ public:
// ReshapeOp creates a new view descriptor of the proper rank.
// For now, the only conversion supported is for target MemRef with static sizes
// and strides.
template <typename ReshapeOp>
class ReshapeOpConversion : public ConvertOpToLLVMPattern<ReshapeOp> {
public:
using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
LogicalResult
matchAndRewrite(ReshapeOp reshapeOp, ArrayRef<Value> operands,
@ -118,8 +120,9 @@ public:
ReshapeOpAdaptor adaptor(operands);
MemRefDescriptor baseDesc(adaptor.src());
Location loc = reshapeOp->getLoc();
auto desc = MemRefDescriptor::undef(rewriter, reshapeOp->getLoc(),
typeConverter->convertType(dstType));
auto desc =
MemRefDescriptor::undef(rewriter, reshapeOp->getLoc(),
this->typeConverter->convertType(dstType));
desc.setAllocatedPtr(rewriter, loc, baseDesc.allocatedPtr(rewriter, loc));
desc.setAlignedPtr(rewriter, loc, baseDesc.alignedPtr(rewriter, loc));
desc.setOffset(rewriter, loc, baseDesc.offset(rewriter, loc));
@ -149,7 +152,8 @@ public:
/// Populate the given list with patterns that convert from Linalg to LLVM.
void mlir::populateLinalgToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
patterns.add<RangeOpConversion, ReshapeOpConversion, YieldOpConversion>(
patterns.add<RangeOpConversion, ReshapeOpConversion<ExpandShapeOp>,
ReshapeOpConversion<CollapseShapeOp>, YieldOpConversion>(
converter);
// Populate the type conversions for the linalg types.

View File

@ -191,7 +191,8 @@ void ConvertLinalgToStandardPass::runOnOperation() {
target.addLegalDialect<AffineDialect, memref::MemRefDialect, scf::SCFDialect,
StandardOpsDialect>();
target.addLegalOp<ModuleOp, FuncOp, ReturnOp>();
target.addLegalOp<linalg::ReshapeOp, linalg::RangeOp>();
target.addLegalOp<linalg::ExpandShapeOp, linalg::CollapseShapeOp,
linalg::RangeOp>();
RewritePatternSet patterns(&getContext());
populateLinalgToStandardConversionPatterns(patterns);
if (failed(applyFullConversion(module, target, std::move(patterns))))

View File

@ -1188,16 +1188,20 @@ public:
getIdentityExprs(resultTy.getShape().size())};
auto collapsedTy = RankedTensorType::get({totalElems}, elemTy);
Value collapsedOp = rewriter.create<linalg::TensorReshapeOp>(
Value collapsedOp = rewriter.create<linalg::TensorCollapseShapeOp>(
loc, collapsedTy, args[0], collapsingMap);
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
reshape, resultTy, collapsedOp, expandingMap);
return success();
}
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
reshape, resultTy, args[0], reassociationMap);
if (resultTy.getRank() < args[0].getType().cast<ShapedType>().getRank())
rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
reshape, resultTy, args[0], reassociationMap);
else
rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
reshape, resultTy, args[0], reassociationMap);
return success();
}

View File

@ -882,13 +882,14 @@ struct FoldInitTensorWithSubTensorOp : public OpRewritePattern<SubTensorOp> {
}
};
template <typename TensorReshapeOp>
struct FoldInitTensorWithTensorReshapeOp
: public OpRewritePattern<TensorReshapeOp> {
using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
PatternRewriter &rewriter) const override {
if (!reshapeOp.src().getDefiningOp<InitTensorOp>())
if (!reshapeOp.src().template getDefiningOp<InitTensorOp>())
return failure();
Location loc = reshapeOp.getLoc();
SmallVector<SmallVector<Value>, 4> resultShapes;
@ -912,7 +913,9 @@ struct FoldInitTensorWithTensorReshapeOp
void InitTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<FoldInitTensorWithSubTensorOp, FoldInitTensorWithTensorReshapeOp,
results.add<FoldInitTensorWithSubTensorOp,
FoldInitTensorWithTensorReshapeOp<TensorExpandShapeOp>,
FoldInitTensorWithTensorReshapeOp<TensorCollapseShapeOp>,
ReplaceStaticShapeDims>(context);
}
@ -1206,12 +1209,24 @@ static void print(OpAsmPrinter &p, ReshapeLikeOp op) {
p << ": " << op.src().getType() << " into " << op.getType();
}
static void print(OpAsmPrinter &p, linalg::ReshapeOp op) {
print<linalg::ReshapeOp>(p, op);
static void print(OpAsmPrinter &p, linalg::ExpandShapeOp op) {
print<linalg::ExpandShapeOp>(p, op);
}
static void print(OpAsmPrinter &p, linalg::TensorReshapeOp op) {
print<linalg::TensorReshapeOp>(p, op);
static void print(OpAsmPrinter &p, linalg::CollapseShapeOp op) {
print<linalg::CollapseShapeOp>(p, op);
}
static void print(OpAsmPrinter &p, linalg::TensorExpandShapeOp op) {
print<linalg::TensorExpandShapeOp>(p, op);
}
static void print(OpAsmPrinter &p, linalg::TensorCollapseShapeOp op) {
print<linalg::TensorCollapseShapeOp>(p, op);
}
static constexpr StringRef getReassociationAttrName() {
return "reassociation";
}
static ParseResult parseReshapeLikeOp(OpAsmParser &parser,
@ -1253,7 +1268,7 @@ static ParseResult parseReshapeLikeOp(OpAsmParser &parser,
break;
}
result.addAttribute(ReshapeOp::getReassociationAttrName(),
result.addAttribute(getReassociationAttrName(),
b.getArrayAttr(reassociation));
// Parse optional attributes.
@ -1334,36 +1349,10 @@ struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
ShapedType srcReshapeSrcType = srcReshapeOp.getSrcType();
ShapedType intermediateType = reshapeOp.getSrcType();
ShapedType resultType = reshapeOp.getResultType();
auto areReshapeOpsFoldable = [](ShapedType largerType,
ShapedType intermediateType,
ShapedType smallerType) -> bool {
return largerType.getRank() > intermediateType.getRank() &&
intermediateType.getRank() > smallerType.getRank();
};
Optional<SmallVector<ReassociationIndices>> reassociationIndices =
llvm::None;
// Check if producer and consumer are both expanding dims or both collapsing
// dims. In this case, try to compose the affine maps. This works for
// dynamic shapes too.
if (areReshapeOpsFoldable(resultType, intermediateType,
srcReshapeSrcType) ||
areReshapeOpsFoldable(srcReshapeSrcType, intermediateType,
resultType)) {
reassociationIndices = collapseReassociationIndices(
srcReshapeOp.getReassociationMaps(), reshapeOp.getReassociationMaps(),
rewriter.getContext());
}
if (!reassociationIndices) {
// If the source reshape can be collapsed/expanded into the target reshape
// they can still be folded. This can only be reasoned about statically
// for cases where
// - either all shapes are static, or
// - The number of dynamic dimensions matches in the source of source and
// result with all other dimensions being 1.
reassociationIndices =
getReassociationIndicesForReshape(srcReshapeSrcType, resultType);
}
collapseReassociationIndices(srcReshapeOp.getReassociationMaps(),
reshapeOp.getReassociationMaps(),
rewriter.getContext());
if (!reassociationIndices)
return failure();
rewriter.replaceOpWithNewOp<ReshapeOpTy>(
@ -1371,15 +1360,55 @@ struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
return success();
}
};
/// Pattern to collapse producer/consumer reshape ops that are both collapsing
/// dimensions or are both expanding dimensions.
template <typename ReshapeOpTy, typename InverseReshapeOpTy>
struct CollapseMixedReshapeOps : public OpRewritePattern<ReshapeOpTy> {
using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
PatternRewriter &rewriter) const override {
auto srcReshapeOp =
reshapeOp.src().template getDefiningOp<InverseReshapeOpTy>();
if (!srcReshapeOp)
return failure();
ShapedType srcReshapeSrcType = srcReshapeOp.getSrcType();
ShapedType intermediateType = reshapeOp.getSrcType();
ShapedType resultType = reshapeOp.getResultType();
// If the source reshape can be collapsed/expanded into the target reshape
// they can still be folded. This can only be reasoned about statically
// for cases where
// - either all shapes are static, or
// - The number of dynamic dimensions matches in the source of source and
// result with all other dimensions being 1.
Optional<SmallVector<ReassociationIndices>> reassociationIndices =
getReassociationIndicesForReshape(srcReshapeSrcType, resultType);
if (!reassociationIndices)
return failure();
bool originalOpExpands =
intermediateType.getRank() > srcReshapeSrcType.getRank();
bool resultingOpExpands =
resultType.getRank() > srcReshapeSrcType.getRank();
if (!(resultingOpExpands ^ originalOpExpands))
rewriter.replaceOpWithNewOp<InverseReshapeOpTy>(
reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
else
rewriter.replaceOpWithNewOp<ReshapeOpTy>(
reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
return success();
}
};
} // namespace
template <typename ReshapeOpTy>
template <typename ReshapeOpTy, typename InverseReshapeOpTy>
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
ArrayRef<Attribute> operands) {
// Fold producer-consumer reshape ops that where the operand type of the
// producer is same as the return type of the consumer.
ReshapeOpTy reshapeSrcOp =
reshapeOp.src().template getDefiningOp<ReshapeOpTy>();
auto reshapeSrcOp =
reshapeOp.src().template getDefiningOp<InverseReshapeOpTy>();
if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
return reshapeSrcOp.src();
// Reshape of a constant can be replaced with a new constant.
@ -1564,20 +1593,38 @@ convertReassociationIndicesToExprs(
return reassociationMaps;
}
SmallVector<AffineMap, 4> ReshapeOp::getReassociationMaps() {
SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
return getSymbolLessAffineMaps(getReassociationExprs());
}
SmallVector<ReassociationExprs, 4> ReshapeOp::getReassociationExprs() {
SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
OpBuilder b(this->getContext());
return convertReassociationIndicesToExprs(b, getReassociationIndices());
}
SmallVector<AffineMap, 4> TensorReshapeOp::getReassociationMaps() {
SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
return getSymbolLessAffineMaps(getReassociationExprs());
}
SmallVector<ReassociationExprs, 4> TensorReshapeOp::getReassociationExprs() {
SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
OpBuilder b(this->getContext());
return convertReassociationIndicesToExprs(b, getReassociationIndices());
}
SmallVector<AffineMap, 4> TensorCollapseShapeOp::getReassociationMaps() {
return getSymbolLessAffineMaps(getReassociationExprs());
}
SmallVector<ReassociationExprs, 4>
TensorCollapseShapeOp::getReassociationExprs() {
OpBuilder b(this->getContext());
return convertReassociationIndicesToExprs(b, getReassociationIndices());
}
SmallVector<AffineMap, 4> TensorExpandShapeOp::getReassociationMaps() {
return getSymbolLessAffineMaps(getReassociationExprs());
}
SmallVector<ReassociationExprs, 4>
TensorExpandShapeOp::getReassociationExprs() {
OpBuilder b(this->getContext());
return convertReassociationIndicesToExprs(b, getReassociationIndices());
}
/// For reshape op compute the shape at dimension `dimIndex` of the output in
/// terms of shape of the `src`, when the reshape op is a collapsing
/// operation. It is the product of the shape of the collapsed dimensions of the
@ -1708,7 +1755,7 @@ getReassociationIndicesAttribute(OpBuilder &b,
return b.getArrayAttr(reassociationAttr);
}
void mlir::linalg::ReshapeOp::build(
void mlir::linalg::ExpandShapeOp::build(
OpBuilder &b, OperationState &result, Value src,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<NamedAttribute> attrs) {
@ -1717,20 +1764,26 @@ void mlir::linalg::ReshapeOp::build(
memRefType, getSymbolLessAffineMaps(
convertReassociationIndicesToExprs(b, reassociation)));
build(b, result, resultType, src, attrs);
result.addAttribute(ReshapeOp::getReassociationAttrName(),
result.addAttribute(getReassociationAttrName(),
getReassociationIndicesAttribute(b, reassociation));
}
void mlir::linalg::ReshapeOp::build(
OpBuilder &b, OperationState &result, Type resultType, Value src,
Value mlir::linalg::ExpandShapeOp::getViewSource() { return src(); }
void mlir::linalg::CollapseShapeOp::build(
OpBuilder &b, OperationState &result, Value src,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<NamedAttribute> attrs) {
auto memRefType = src.getType().cast<MemRefType>();
auto resultType = computeReshapeCollapsedType(
memRefType, getSymbolLessAffineMaps(
convertReassociationIndicesToExprs(b, reassociation)));
build(b, result, resultType, src, attrs);
result.addAttribute(ReshapeOp::getReassociationAttrName(),
result.addAttribute(getReassociationAttrName(),
getReassociationIndicesAttribute(b, reassociation));
}
Value mlir::linalg::ReshapeOp::getViewSource() { return src(); }
Value mlir::linalg::CollapseShapeOp::getViewSource() { return src(); }
/// Verify that shapes of the reshaped types using following rules
/// 1) if a dimension in the collapsed type is static, then the corresponding
@ -1785,18 +1838,17 @@ static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType,
// Common verifier for reshape-like types. Fills `expandedType` and
// `collapsedType` with the proper `src` or `result` type.
template <typename Op, typename T>
static LogicalResult verifyReshapeLikeTypes(Op op, T &expandedType,
T &collapsedType) {
expandedType = op.getSrcType();
collapsedType = op.getResultType();
template <typename Op, typename T,
bool isExpansion = std::is_same<Op, TensorExpandShapeOp>::value ||
std::is_same<Op, ExpandShapeOp>::value>
static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
T collapsedType) {
unsigned expandedRank = expandedType.getRank();
unsigned collapsedRank = collapsedType.getRank();
bool isCollapse = expandedRank > collapsedRank;
if (!isCollapse) {
std::swap(expandedRank, collapsedRank);
std::swap(expandedType, collapsedType);
}
if (expandedRank < collapsedRank)
return op.emitOpError("expected the type ")
<< expandedType
<< " to have higher rank than the type = " << collapsedType;
if (expandedRank == 0)
return op.emitOpError("expected non-zero memref ranks");
if (expandedRank == collapsedRank)
@ -1825,11 +1877,13 @@ static LogicalResult verifyReshapeLikeTypes(Op op, T &expandedType,
if (!isReassociationValid(maps, &invalidIdx))
return op.emitOpError("expected reassociation map #")
<< invalidIdx << " to be valid and contiguous";
return verifyReshapeLikeShapes(op, collapsedType, expandedType, !isCollapse);
return verifyReshapeLikeShapes(op, collapsedType, expandedType, isExpansion);
}
static LogicalResult verify(ReshapeOp op) {
MemRefType expandedType, collapsedType;
template <typename TensorReshapeOp>
static LogicalResult verifyReshapeOp(TensorReshapeOp op,
MemRefType expandedType,
MemRefType collapsedType) {
if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType)))
return failure();
auto maps = op.getReassociationMaps();
@ -1840,9 +1894,24 @@ static LogicalResult verify(ReshapeOp op) {
return success();
}
void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<CollapseReshapeOps<ReshapeOp>>(context);
static LogicalResult verify(ExpandShapeOp op) {
return verifyReshapeOp(op, op.getResultType(), op.getSrcType());
}
void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<CollapseReshapeOps<ExpandShapeOp>,
CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context);
}
static LogicalResult verify(CollapseShapeOp op) {
return verifyReshapeOp(op, op.getSrcType(), op.getResultType());
}
void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<CollapseReshapeOps<CollapseShapeOp>,
CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>>(context);
}
//===----------------------------------------------------------------------===//
@ -1877,7 +1946,7 @@ computeTensorReshapeCollapsedType(RankedTensorType type,
return RankedTensorType::get(newShape, type.getElementType());
}
void mlir::linalg::TensorReshapeOp::build(
void mlir::linalg::TensorCollapseShapeOp::build(
OpBuilder &b, OperationState &result, Value src,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<NamedAttribute> attrs) {
@ -1886,21 +1955,27 @@ void mlir::linalg::TensorReshapeOp::build(
getSymbolLessAffineMaps(
convertReassociationIndicesToExprs(b, reassociation)));
build(b, result, resultType, src, attrs);
result.addAttribute(ReshapeOp::getReassociationAttrName(),
result.addAttribute(getReassociationAttrName(),
getReassociationIndicesAttribute(b, reassociation));
}
void mlir::linalg::TensorReshapeOp::build(
OpBuilder &b, OperationState &result, Type resultType, Value src,
void mlir::linalg::TensorExpandShapeOp::build(
OpBuilder &b, OperationState &result, Value src,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<NamedAttribute> attrs) {
auto resultType = computeTensorReshapeCollapsedType(
src.getType().cast<RankedTensorType>(),
getSymbolLessAffineMaps(
convertReassociationIndicesToExprs(b, reassociation)));
build(b, result, resultType, src, attrs);
result.addAttribute(ReshapeOp::getReassociationAttrName(),
result.addAttribute(getReassociationAttrName(),
getReassociationIndicesAttribute(b, reassociation));
}
static LogicalResult verify(TensorReshapeOp op) {
RankedTensorType expandedType, collapsedType;
template <typename TensorReshapeOp>
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
RankedTensorType expandedType,
RankedTensorType collapsedType) {
if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType)))
return failure();
@ -1913,9 +1988,18 @@ static LogicalResult verify(TensorReshapeOp op) {
return success();
}
static LogicalResult verify(TensorExpandShapeOp op) {
return verifyTensorReshapeOp(op, op.getResultType(), op.getSrcType());
}
static LogicalResult verify(TensorCollapseShapeOp op) {
return verifyTensorReshapeOp(op, op.getSrcType(), op.getResultType());
}
namespace {
/// Reshape of a splat constant can be replaced with a constant of the result
/// type.
template <typename TensorReshapeOp>
struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
@ -1936,11 +2020,12 @@ struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
///
/// For such op chains, we can create new linalg.fill ops with the result
/// type of the linalg.tensor_reshape op.
template <typename TensorReshapeOp>
struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
PatternRewriter &rewriter) const override {
auto oldFill = reshapeOp.src().getDefiningOp<FillOp>();
auto oldFill = reshapeOp.src().template getDefiningOp<FillOp>();
if (!oldFill)
return failure();
@ -1955,14 +2040,38 @@ struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
};
} // namespace
void TensorReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<CollapseReshapeOps<TensorReshapeOp>, FoldFillWithTensorReshape,
FoldInitTensorWithTensorReshapeOp, FoldReshapeWithConstant>(
context);
void TensorExpandShapeOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results
.add<CollapseReshapeOps<TensorExpandShapeOp>,
CollapseMixedReshapeOps<TensorExpandShapeOp, TensorCollapseShapeOp>,
FoldFillWithTensorReshape<TensorExpandShapeOp>,
FoldInitTensorWithTensorReshapeOp<TensorExpandShapeOp>,
FoldReshapeWithConstant<TensorExpandShapeOp>>(context);
}
LogicalResult TensorReshapeOp::reifyReturnTypeShapesPerResultDim(
void TensorCollapseShapeOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results
.add<CollapseReshapeOps<TensorCollapseShapeOp>,
CollapseMixedReshapeOps<TensorCollapseShapeOp, TensorExpandShapeOp>,
FoldFillWithTensorReshape<TensorCollapseShapeOp>,
FoldInitTensorWithTensorReshapeOp<TensorCollapseShapeOp>,
FoldReshapeWithConstant<TensorCollapseShapeOp>>(context);
}
LogicalResult TensorExpandShapeOp::reifyReturnTypeShapesPerResultDim(
OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
auto resultShape =
getAsValues(b, getLoc(),
getReshapeOutputShapeFromInputShape(
b, getLoc(), src(), getResultType().getShape(),
getReassociationMaps()));
reifiedReturnShapes.emplace_back(std::move(resultShape));
return success();
}
LogicalResult TensorCollapseShapeOp::reifyReturnTypeShapesPerResultDim(
OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
auto resultShape =
getAsValues(b, getLoc(),
@ -2753,13 +2862,23 @@ std::string mlir::linalg::generateLibraryCallName(Operation *op) {
// TODO: Consider making all this boilerplate easy to autogenerate
// with Tablegen. This seems a desirable property in the context of
// OpInterfaces where a Linalg "named" op **isa** LinalgOp.
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
if (succeeded(foldMemRefCast(*this)))
return getResult();
return foldReshapeOp(*this, operands);
return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
}
OpFoldResult TensorReshapeOp::fold(ArrayRef<Attribute> operands) {
return foldReshapeOp(*this, operands);
OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
if (succeeded(foldMemRefCast(*this)))
return getResult();
return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
}
OpFoldResult TensorExpandShapeOp::fold(ArrayRef<Attribute> operands) {
return foldReshapeOp<TensorExpandShapeOp, TensorCollapseShapeOp>(*this,
operands);
}
OpFoldResult TensorCollapseShapeOp::fold(ArrayRef<Attribute> operands) {
return foldReshapeOp<TensorCollapseShapeOp, TensorExpandShapeOp>(*this,
operands);
}
//===----------------------------------------------------------------------===//

View File

@ -149,17 +149,25 @@ public:
/// Conversion pattern that replaces `linalg.tensor_expand_shape` /
/// `linalg.tensor_collapse_shape` with the corresponding memref ops
/// `linalg.expand_shape` / `linalg.collapse_shape`.
template <typename TensorReshapeOp,
typename Adaptor = typename TensorReshapeOp::Adaptor>
class BufferizeTensorReshapeOp : public OpConversionPattern<TensorReshapeOp> {
public:
using OpConversionPattern<TensorReshapeOp>::OpConversionPattern;
using ReshapeOp = typename std::conditional_t<
std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value, ExpandShapeOp,
CollapseShapeOp>;
LogicalResult
matchAndRewrite(TensorReshapeOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
linalg::TensorReshapeOpAdaptor adaptor(operands, op->getAttrDictionary());
rewriter.replaceOpWithNewOp<linalg::ReshapeOp>(
op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
adaptor.src(), adaptor.reassociation());
Adaptor adaptor(operands, op->getAttrDictionary());
rewriter.replaceOpWithNewOp<ReshapeOp>(op,
this->getTypeConverter()
->convertType(op.getType())
.template cast<MemRefType>(),
adaptor.src(),
adaptor.reassociation());
return success();
}
};
@ -348,7 +356,8 @@ void mlir::linalg::populateLinalgBufferizePatterns(
BufferizeAnyLinalgOp,
BufferizeFillOp,
BufferizeInitTensorOp,
BufferizeTensorReshapeOp,
BufferizeTensorReshapeOp<TensorExpandShapeOp>,
BufferizeTensorReshapeOp<TensorCollapseShapeOp>,
SubTensorOpConverter,
SubTensorInsertOpConverter
>(typeConverter, patterns.getContext());

View File

@ -31,7 +31,7 @@ static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
// FromElementsOp results in a tensor<1xdtype>, we need to reshape that to
// a tensor<dtype> instead.
return builder.create<linalg::TensorReshapeOp>(
return builder.create<linalg::TensorCollapseShapeOp>(
loc, type, createNewTensorOp, ArrayRef<ReassociationExprs>{});
}
@ -159,8 +159,8 @@ public:
/// Canonicalizes the pattern of the form
///
/// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
/// %reshaped_tensor = linalg.tensor_reshape %tensor [] : tensor<1xi32> into
/// tensor<i32>
/// %reshaped_tensor = linalg.tensor_collapse_shape %tensor []
/// : tensor<1xi32> into tensor<i32>
/// %extracted_element = tensor.extract %reshaped_tensor[] : tensor<i32>
///
/// to just %element.
@ -170,10 +170,11 @@ struct ExtractFromReshapeFromElements
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
PatternRewriter &rewriter) const final {
if (extract.indices().size() != 0)
if (!extract.indices().empty())
return failure();
auto tensorReshape = extract.tensor().getDefiningOp<TensorReshapeOp>();
auto tensorReshape =
extract.tensor().getDefiningOp<TensorCollapseShapeOp>();
if (tensorReshape == nullptr)
return failure();

View File

@ -362,10 +362,11 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
for (auto operand : llvm::enumerate(values)) {
if (operand.value().getType() == newInputOutputTypes[flattenedIdx])
res.push_back(operand.value());
else
res.push_back(rewriter.create<linalg::TensorReshapeOp>(
else {
res.push_back(rewriter.create<TensorCollapseShapeOp>(
loc, newInputOutputTypes[flattenedIdx], operand.value(),
convertAffineMapArrayToExprs(reassociationMaps[flattenedIdx])));
}
++flattenedIdx;
}
return res;
@ -395,11 +396,11 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
RankedTensorType origResultType = genericOp.getResult(result.index())
.getType()
.template cast<RankedTensorType>();
if (origResultType != result.value().getType())
resultReplacements.push_back(rewriter.create<linalg::TensorReshapeOp>(
if (origResultType != result.value().getType()) {
resultReplacements.push_back(rewriter.create<TensorExpandShapeOp>(
loc, origResultType, result.value(),
convertAffineMapArrayToExprs(reassociationMaps[index])));
else
} else
resultReplacements.push_back(result.value());
}
rewriter.replaceOp(genericOp, resultReplacements);
@ -460,8 +461,8 @@ struct UseRankReducedSubTensorOp : public OpRewritePattern<SubTensorOp> {
Location loc = subTensorOp.getLoc();
Value newSubTensor = rewriter.create<SubTensorOp>(
loc, rankReducedType, subTensorOp.source(), offsets, sizes, strides);
rewriter.replaceOpWithNewOp<TensorReshapeOp>(subTensorOp, resultType,
newSubTensor, *reassociation);
rewriter.replaceOpWithNewOp<TensorExpandShapeOp>(
subTensorOp, resultType, newSubTensor, *reassociation);
return success();
}
};
@ -482,7 +483,7 @@ struct UseRankReducedSubTensorInsertOp
reassociation->size() == static_cast<size_t>(sourceType.getRank()))
return failure();
Location loc = insertOp.getLoc();
auto reshapedSource = rewriter.create<TensorReshapeOp>(
auto reshapedSource = rewriter.create<TensorCollapseShapeOp>(
loc, insertOp.source(), *reassociation);
rewriter.replaceOpWithNewOp<SubTensorInsertOp>(
insertOp, reshapedSource, insertOp.dest(), insertOp.getMixedOffsets(),
@ -500,7 +501,8 @@ void mlir::linalg::populateFoldUnitExtentDimsPatterns(
patterns.add<FoldUnitDimLoops, ReplaceUnitExtentTensors,
UseRankReducedSubTensorOp, UseRankReducedSubTensorInsertOp>(
context);
TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context);
TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context);
}
namespace {

View File

@ -313,8 +313,7 @@ fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
/// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
///
/// and reshape:
/// %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
/// affine_map<(i, j, k, l) -> (j, k, l)>] :
/// %1 = linalg.tensor_collapse_shape %0 [[0], [0, 1, 2]] :
/// tensor<?x?x4x5xf32> into tensor<?x?xf32>
///
/// would be rewritten into:
@ -355,24 +354,21 @@ static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
resultExprs, context);
}
/// Checks if the `reshapeOp` can be fused with it consumer (if `asProducer` is
/// true) or its producer (if `asProducer` is false) given the indexing map at
/// its use.
static bool isTensorReshapeOpFoldableByLinearization(TensorReshapeOp reshapeOp,
AffineMap useIndexMap,
bool asProducer) {
RankedTensorType returnType = reshapeOp.getResultType();
RankedTensorType operandType = reshapeOp.getSrcType();
// Reshape is fusable with its consumer (i.e. reshape as a producer) when its
// operand is of lesser rank than the result. Fusing when operand has higher
// rank will require use of mods and divs in the indexing maps of the fused op
// which would make it non-invertible. Similarly reshape is fused with its
// producer (i.e. reshape as consumer) only if the return type has lesser
// rank.
if ((asProducer && reshapeOp.getSrcType().hasStaticShape() &&
returnType.getRank() < operandType.getRank()) ||
(!asProducer && reshapeOp.getResultType().hasStaticShape() &&
operandType.getRank() < returnType.getRank()))
// A TensorExpandShapeOp is fusable by linearization with its consumer (i.e.
// the reshape acts as a producer). Rejects fusion as a consumer when the
// expanded result shape is static: that direction would require mods and divs
// in the fused op's indexing maps, making them non-invertible. Otherwise the
// reshape is foldable iff the use's indexing map is a permutation, so the
// linearized map stays well-defined.
static bool isTensorReshapeOpFoldableByLinearization(
    TensorExpandShapeOp expandOp, AffineMap useIndexMap, bool asProducer) {
  if (!asProducer && expandOp.getResultType().hasStaticShape())
    return false;
  return useIndexMap.isPermutation();
}
// A TensorCollapseShapeOp is fusable by linearization with its producer (i.e.
// the reshape acts as a consumer). Symmetric to the expand-shape overload:
// rejects fusion as a producer when the collapse source shape is static;
// otherwise foldable iff the use's indexing map is a permutation.
static bool isTensorReshapeOpFoldableByLinearization(
    TensorCollapseShapeOp collapseOp, AffineMap useIndexMap, bool asProducer) {
  if (asProducer && collapseOp.getSrcType().hasStaticShape())
    return false;
  return useIndexMap.isPermutation();
}
@ -405,17 +401,14 @@ static bool isUnitDimExpansionOnly(ArrayRef<int64_t> expandedShape,
/// indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
/// affine_map<(d0, d1, d2) -> (d1, d2)>,
/// affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
/// %d = linalg.tensor_reshape %c
/// [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
/// affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
/// affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>]
/// %d = linalg.tensor_expand_shape %c [[0, 1], [2], [3, 4, 5]]
/// : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///
/// The reshape can be folded into the `genericOp` if its loop dimensionality
/// is increased to match the result (operand) of the tensor_reshape when the
/// reshape is expanding (folding). The indexing_map of the fused tensor in the
/// `genericOp` and the reassociation map helps compute the indexing maps of
/// the modified op. For the above example, based on the reassociation map it
/// is increased to match the result (operand) of the tensor_expand_shape.
/// The indexing_map of the fused tensor in the `genericOp` and the
/// reassociation map helps compute the indexing maps of the modified op.
/// For the above example, based on the reassociation map it
/// can be concluded that
///
/// - The loop used to access the first dimension of the fused tensor is split
@ -443,14 +436,9 @@ static bool isUnitDimExpansionOnly(ArrayRef<int64_t> expandedShape,
/// Since operands to the linalg generic are now 5D, reshapes can be introduced
/// to make it consistent
///
/// %0 = linalg.tensor_reshape %a
/// [affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e2),
/// affine_map<(e0, e1, e2, e3, e4, e5) -> (e3, e4),
/// affine_map<(e0, e1, e2, e3, e4, e5) -> (e5)]
/// %0 = linalg.tensor_expand_shape %a [[0, 1, 2], [3, 4], [5]]
/// : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
/// %1 = linalg.tensor_reshape %b
/// [affine_map<(e0, e1, e2, e3) -> (e0, e1, e2),
/// affine_map<(e0, e1, e2, e3) -> (e3)]
/// %1 = linalg.tensor_expand_shape %b [[0, 1, 2], [3]]
/// : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
///
/// The added reshapes are again expanding patterns, so they will get fused
@ -614,11 +602,12 @@ static RankedTensorType getExpandedType(RankedTensorType originalType,
return RankedTensorType::get(expandedShape, originalType.getElementType());
}
/// Returns the reassociation maps to use in the `linalg.tensor_reshape`
/// operation to convert the operands of the origial operation to operands of
/// Returns the reassociation maps to use in the `linalg.tensor_expand_shape`
/// operation to convert the operands of the original operation to operands of
/// the expanded operation. The same method is used to compute the
/// `linalg.tensor_reshape` used to collapse the result of the expanded op to
/// get the value that can replace all uses of the results of the original op.
/// `linalg.tensor_collapse_shape` used to collapse the result of the expanded
/// op to get the value that can replace all uses of the results of the original
/// op.
static SmallVector<ReassociationIndices>
getReassociationForExpansion(AffineMap indexingMap,
const ExpansionInfo &expansionInfo) {
@ -678,25 +667,29 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
}
}
/// Implements the fusion of a tensor_reshape op and a generic op as explained
/// in `isFusableWithReshapeByExpansion`. Assumes that those conditions have
/// been satisfied.
/// Implements the fusion of a tensor_collapse_shape or a tensor_expand_shape op
/// and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes
/// that those conditions have been satisfied.
static Optional<SmallVector<Value>>
fuseWithReshapeByExpansion(GenericOp genericOp, TensorReshapeOp reshapeOp,
fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
OpOperand *fusableOpOperand,
PatternRewriter &rewriter) {
assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
"preconditions for fuse operation failed");
// Check if reshape is expanding or collapsing.
bool isExpanding =
reshapeOp.getSrcType().getRank() < reshapeOp.getResultType().getRank();
RankedTensorType expandedType =
isExpanding ? reshapeOp.getResultType() : reshapeOp.getSrcType();
auto expandingReshapeOp = dyn_cast<TensorExpandShapeOp>(*reshapeOp);
auto collapsingReshapeOp = dyn_cast<TensorCollapseShapeOp>(*reshapeOp);
bool isExpanding = (expandingReshapeOp != nullptr);
RankedTensorType expandedType = isExpanding
? expandingReshapeOp.getResultType()
: collapsingReshapeOp.getSrcType();
ExpansionInfo expansionInfo;
if (failed(expansionInfo.compute(genericOp, fusableOpOperand,
reshapeOp.getReassociationMaps(),
expandedType.getShape(), rewriter)))
if (failed(expansionInfo.compute(
genericOp, fusableOpOperand,
isExpanding ? expandingReshapeOp.getReassociationMaps()
: collapsingReshapeOp.getReassociationMaps(),
expandedType.getShape(), rewriter)))
return llvm::None;
if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
@ -710,7 +703,8 @@ fuseWithReshapeByExpansion(GenericOp genericOp, TensorReshapeOp reshapeOp,
SmallVector<Value> expandedOpOperands;
for (OpOperand *opOperand : genericOp.getInputOperands()) {
if (opOperand == fusableOpOperand) {
expandedOpOperands.push_back(reshapeOp.src());
expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.src()
: collapsingReshapeOp.src());
continue;
}
AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
@ -721,7 +715,7 @@ fuseWithReshapeByExpansion(GenericOp genericOp, TensorReshapeOp reshapeOp,
// Reshape the operand to get the right type.
SmallVector<ReassociationIndices> reassociation =
getReassociationForExpansion(indexingMap, expansionInfo);
expandedOpOperands.push_back(rewriter.create<TensorReshapeOp>(
expandedOpOperands.push_back(rewriter.create<TensorExpandShapeOp>(
genericOp.getLoc(), expandedOperandType, opOperand->get(),
reassociation));
continue;
@ -739,7 +733,7 @@ fuseWithReshapeByExpansion(GenericOp genericOp, TensorReshapeOp reshapeOp,
if (expandedOutputType != opOperand->get().getType()) {
SmallVector<ReassociationIndices> reassociation =
getReassociationForExpansion(indexingMap, expansionInfo);
outputs.push_back(rewriter.create<TensorReshapeOp>(
outputs.push_back(rewriter.create<TensorExpandShapeOp>(
genericOp.getLoc(), expandedOutputType, opOperand->get(),
reassociation));
}
@ -772,7 +766,7 @@ fuseWithReshapeByExpansion(GenericOp genericOp, TensorReshapeOp reshapeOp,
genericOp.getTiedIndexingMap(
genericOp.getOutputOperand(resultNumber)),
expansionInfo);
resultVals.push_back(rewriter.create<TensorReshapeOp>(
resultVals.push_back(rewriter.create<TensorCollapseShapeOp>(
genericOp.getLoc(), opResult.getType(),
fusedOp->getResult(resultNumber), reassociation));
} else {
@ -785,18 +779,15 @@ fuseWithReshapeByExpansion(GenericOp genericOp, TensorReshapeOp reshapeOp,
namespace {
/// Pattern to fold tensor_reshape op with its consumer by using the source of
/// the reshape op as the operand in the consumer (instead of the result of the
/// tensor_reshapeop) when the tensor_reshape op is collapsing. The
/// corresponding index map in the consumer needs to be modified to linearize
/// the folded dimension.
/// Pattern to fold tensor_expand_shape op with its consumer by using the source
/// of the reshape op as the operand in the consumer (instead of the result of
/// the tensor_expand_shape). The corresponding index map in the consumer
/// needs to be modified to linearize the folded dimension.
///
/// For example,
///
/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
/// %0 = linalg.tensor_reshape %arg0
/// [affine_map<(i, j, k, l) -> (i)>, affine_map<(i, j, k, l) -> (j, k)>,
/// affine_map<(i, j, k, l) -> (l)>]
/// %0 = linalg.tensor_expand_shape %arg0 [[0], [1, 2], [3]]
/// tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
/// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
/// ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
@ -809,7 +800,7 @@ namespace {
/// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
/// ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
/// -> tensor<?x?x4x?xf32>
template <bool foldUnitDimReshapesOnly>
template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
struct FoldProducerReshapeOpByLinearization
: public OpRewritePattern<GenericOp> {
using OpRewritePattern<GenericOp>::OpRewritePattern;
@ -820,14 +811,19 @@ struct FoldProducerReshapeOpByLinearization
return failure();
SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
for (auto en : llvm::enumerate(inputOperands)) {
TensorReshapeOp reshapeOp =
en.value()->get().getDefiningOp<TensorReshapeOp>();
if (!reshapeOp ||
!isTensorReshapeOpFoldableByLinearization(
auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>();
if (!reshapeOp)
continue;
Value src = reshapeOp.src();
RankedTensorType operandType = reshapeOp.getSrcType();
RankedTensorType returnType = reshapeOp.getResultType();
if (!isTensorReshapeOpFoldableByLinearization(
reshapeOp, genericOp.getTiedIndexingMap(en.value()),
/*asProducer =*/true) ||
(foldUnitDimReshapesOnly &&
!isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(),
!isUnitDimExpansionOnly(returnType.getShape(),
reshapeOp.getReassociationMaps())))
continue;
@ -845,9 +841,8 @@ struct FoldProducerReshapeOpByLinearization
auto invMap = inversePermutation(fusedIndexMaps[en.index()]);
// Compute the indexing map to use for the result of the producer.
AffineMap modifiedMap =
linearizeCollapsedDims(invMap, reshapeOp.getResultType().getShape(),
reshapeOp.getReassociationMaps());
AffineMap modifiedMap = linearizeCollapsedDims(
invMap, returnType.getShape(), reshapeOp.getReassociationMaps());
for (AffineExpr expr : modifiedMap.getResults()) {
if (!expr.isPureAffine())
return failure();
@ -893,8 +888,7 @@ getReassociationIndices(ArrayRef<AffineMap> maps) {
///
/// For example,
///
/// %0 = linalg.tensor_reshape %A [
/// affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
/// %0 = linalg.tensor_expand_shape %A [[0, 1], [2]]
/// : tensor<12544x16xf32> into tensor<112x112x16xf32>
/// %2 = linalg.generic {indexing_maps = [
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
@ -912,8 +906,7 @@ getReassociationIndices(ArrayRef<AffineMap> maps) {
/// iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1
/// : tensor<12544x16xf32>, tensor<16xf32>) outs(%1 : tensor<12544x16xf32>) {
/// } -> tensor<12544x16xf32>
/// %3 = linalg.tensor_reshape %2 [
/// #affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
/// %3 = linalg.tensor_expand_shape %2 [[0, 1], [2]]
/// : tensor<12544x16xf32> into tensor<112x112x16xf32>
struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
using OpRewritePattern<GenericOp>::OpRewritePattern;
@ -932,17 +925,15 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
return failure();
int64_t destRank = genericOp.getNumParallelLoops();
SmallVector<Value> newOperands = genericOp.getInputOperands();
TensorReshapeOp reshapeFound;
// 1. Look for tensor_reshape operands and figure out save the dimensions
// merged.
TensorExpandShapeOp reshapeFound;
// 1. Look for tensor_expand_shape operands and record which dimensions
// they merge.
SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
for (auto en : llvm::enumerate(inputOperands)) {
TensorReshapeOp reshapeOp =
en.value()->get().template getDefiningOp<TensorReshapeOp>();
if (!reshapeOp || reshapeOp.getSrcType().getRank() >
reshapeOp.getResultType().getRank()) {
auto reshapeOp =
en.value()->get().template getDefiningOp<TensorExpandShapeOp>();
if (!reshapeOp)
continue;
}
// TODO: We could support non-identity map as long as the merged
// dimensions are still contiguous.
if (!genericOp.getTiedIndexingMap(en.value()).isIdentity())
@ -1007,7 +998,7 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
auto newOutputType = RankedTensorType::get(
reshapeFound.getSrcType().getShape(),
output.getType().template cast<RankedTensorType>().getElementType());
Value newOutput = rewriter.create<TensorReshapeOp>(
Value newOutput = rewriter.create<TensorCollapseShapeOp>(
genericOp->getLoc(), newOutputType, output, reassociation);
newOutputTypes.push_back(newOutputType);
newOutputs.push_back(newOutput);
@ -1023,7 +1014,7 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
// 6. Reshape the results so that their types match the uses.
SmallVector<Value> newResults;
for (auto result : llvm::enumerate(newOp->getResults())) {
newResults.push_back(rewriter.create<TensorReshapeOp>(
newResults.push_back(rewriter.create<TensorExpandShapeOp>(
genericOp->getLoc(), genericOp.getOutputTensorTypes()[result.index()],
result.value(), reassociation));
}
@ -1032,9 +1023,9 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
}
};
/// Pattern to fuse a tensor_reshape op with its consumer generic op, when the
/// reshape op is collapsing dimensions. The dimensionality of the loop in the
/// consumer is expanded.
/// Pattern to fuse a tensor_collapse_shape op with its consumer generic op,
/// when the reshape op is collapsing dimensions. The dimensionality of the loop
/// in the consumer is expanded.
class FoldWithProducerReshapeOpByExpansion
: public OpRewritePattern<GenericOp> {
public:
@ -1047,16 +1038,14 @@ public:
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
for (OpOperand *opOperand : genericOp.getInputOperands()) {
TensorReshapeOp reshapeOp =
opOperand->get().getDefiningOp<TensorReshapeOp>();
TensorCollapseShapeOp reshapeOp =
opOperand->get().getDefiningOp<TensorCollapseShapeOp>();
if (!reshapeOp)
continue;
// Fold only if
// - The tensor reshape op is folding.
// - All constraints of fusing with reshape by expansion are met.
if (reshapeOp.getSrcType().getRank() <
reshapeOp.getResultType().getRank() ||
!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
(!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
continue;
@ -1074,16 +1063,17 @@ private:
ControlElementwiseOpsFusionFn controlFoldingReshapes;
};
/// Pattern to fold tensor_reshape op with its producer. The corresponding index
/// map in the consumer needs to be modified to linearize the folded dimension.
template <bool foldUnitDimReshapesOnly>
/// Pattern to fold tensor_collapse_shape or tensor_expand_shape op with its
/// producer. The corresponding index map in the consumer needs to be modified
/// to linearize the folded dimension.
template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
struct FoldConsumerReshapeOpByLinearization
: public OpRewritePattern<TensorReshapeOp> {
using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
PatternRewriter &rewriter) const override {
GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>();
if (!producer || !producer.hasTensorSemantics() ||
producer.getNumOutputs() != 1 ||
!isTensorReshapeOpFoldableByLinearization(
@ -1141,19 +1131,14 @@ struct FoldConsumerReshapeOpByLinearization
}
};
/// Pattern to fold a tensor_reshape op with its producer generic op if the
/// tensor_reshape op is expanding, by expanding the dimensionality of the loop
/// in the producer op.
/// Pattern to fold a tensor_expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
: public OpRewritePattern<TensorReshapeOp> {
using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
: public OpRewritePattern<TensorExpandShapeOp> {
using OpRewritePattern<TensorExpandShapeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TensorExpandShapeOp reshapeOp,
PatternRewriter &rewriter) const override {
// Fold only if
// - The tensor reshape op is a expanding case.
// - All constraints of fusing with reshape by expansion are met.
if (reshapeOp.getSrcType().getRank() > reshapeOp.getResultType().getRank())
return failure();
// Fold only if all constraints of fusing with reshape by expansion are met.
GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
if (!producer || producer.getNumOutputs() != 1 ||
!isFusableWithReshapeByDimExpansion(producer,
@ -1260,9 +1245,14 @@ fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
const OpOperand &consumer) {
auto reshapeOp = producer.getDefiningOp<linalg::TensorReshapeOp>();
return !isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
reshapeOp.getReassociationMaps());
auto expandShapeOp = producer.getDefiningOp<linalg::TensorExpandShapeOp>();
if (expandShapeOp)
return !isUnitDimExpansionOnly(expandShapeOp.getSrcType().getShape(),
expandShapeOp.getReassociationMaps());
auto collapseShapeOp =
producer.getDefiningOp<linalg::TensorCollapseShapeOp>();
return !isUnitDimExpansionOnly(collapseShapeOp.getSrcType().getShape(),
collapseShapeOp.getReassociationMaps());
}
namespace {
@ -1375,16 +1365,22 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
RewritePatternSet &patterns) {
patterns.add<FoldProducerReshapeOpByLinearization<false>,
FoldConsumerReshapeOpByLinearization<false>>(
patterns.getContext());
patterns
.add<FoldProducerReshapeOpByLinearization<false, TensorCollapseShapeOp>,
FoldProducerReshapeOpByLinearization<false, TensorExpandShapeOp>,
FoldConsumerReshapeOpByLinearization<false, TensorCollapseShapeOp>,
FoldConsumerReshapeOpByLinearization<false, TensorExpandShapeOp>>(
patterns.getContext());
}
void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
RewritePatternSet &patterns) {
patterns.add<FoldProducerReshapeOpByLinearization<true>,
FoldConsumerReshapeOpByLinearization<true>>(
patterns.getContext());
patterns
.add<FoldProducerReshapeOpByLinearization<true, TensorCollapseShapeOp>,
FoldProducerReshapeOpByLinearization<true, TensorExpandShapeOp>,
FoldConsumerReshapeOpByLinearization<true, TensorCollapseShapeOp>,
FoldConsumerReshapeOpByLinearization<true, TensorExpandShapeOp>>(
patterns.getContext());
}
void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
@ -1406,7 +1402,8 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
GenericOp::getCanonicalizationPatterns(patterns, context);
IndexedGenericOp::getCanonicalizationPatterns(patterns, context);
TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context);
TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context);
}
void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {

View File

@ -61,7 +61,7 @@ func @test_abs(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK-LABEL: @test_broadcast
func @test_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
// CHECK: [[INIT:%.+]] = linalg.init_tensor [2] : tensor<2xf32>
// CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0
// CHECK: [[RESHAPE:%.+]] = linalg.tensor_collapse_shape %arg0
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins([[RESHAPE]], %arg1 : tensor<f32>, tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) {
// CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
// CHECK: [[ELEMENT:%.+]] = addf %arg2, %arg3 : f32
@ -79,7 +79,7 @@ func @test_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32
// CHECK-LABEL: @test_broadcast_swapped_args
func @test_broadcast_swapped_args(%arg0: tensor<2xf32>, %arg1: tensor<1xf32>) -> tensor<2xf32> {
// CHECK: [[INIT:%.+]] = linalg.init_tensor [2] : tensor<2xf32>
// CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg1
// CHECK: [[RESHAPE:%.+]] = linalg.tensor_collapse_shape %arg1
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0, [[RESHAPE]] : tensor<2xf32>, tensor<f32>) outs([[INIT]] : tensor<2xf32>) {
// CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
// CHECK: [[ELEMENT:%.+]] = addf %arg2, %arg3 : f32
@ -98,8 +98,8 @@ func @test_broadcast_swapped_args(%arg0: tensor<2xf32>, %arg1: tensor<1xf32>) ->
// CHECK-LABEL: @test_multibroadcast
func @test_multibroadcast(%arg0: tensor<1x3xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x3xf32> {
// CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 3] : tensor<2x3xf32>
// CHECK: [[RESHAPE1:%.+]] = linalg.tensor_reshape %arg0 {{\[}}[0, 1]]
// CHECK: [[RESHAPE2:%.+]] = linalg.tensor_reshape %arg1 {{\[}}[0, 1]]
// CHECK: [[RESHAPE1:%.+]] = linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1]]
// CHECK: [[RESHAPE2:%.+]] = linalg.tensor_collapse_shape %arg1 {{\[}}[0, 1]]
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins([[RESHAPE1]], [[RESHAPE2]] : tensor<3xf32>, tensor<2xf32>) outs([[INIT]] : tensor<2x3xf32>) {
// CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
// CHECK: [[ELEMENT:%.+]] = addf %arg2, %arg3 : f32
@ -467,7 +467,7 @@ func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
// CHECK-LABEL: @test_reshape_downrank
func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {
// CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 {{\[}}[0, 1]]
// CHECK: [[RESHAPE:%.+]] = linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1]]
%0 = "tosa.reshape"(%arg0) {new_shape = [6]} : (tensor<2x3xf32>) -> tensor<6xf32>
// CHECK: return [[RESHAPE]]
return %0 : tensor<6xf32>
@ -477,7 +477,7 @@ func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {
// CHECK-LABEL: @test_reshape_uprank
func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
// CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 {{\[}}[0, 1]]
// CHECK: [[RESHAPE:%.+]] = linalg.tensor_expand_shape %arg0 {{\[}}[0, 1]]
%0 = "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<6xf32>) -> tensor<2x3xf32>
// CHECK: return [[RESHAPE]]
return %0 : tensor<2x3xf32>
@ -488,8 +488,8 @@ func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
// CHECK-LABEL: @test_reshape_samerank
func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xf32>)
// CHECK-NEXT: %[[RESHAPE1:.*]] = linalg.tensor_reshape %[[ARG0]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[RESHAPE2:.*]] = linalg.tensor_reshape %[[RESHAPE1]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[RESHAPE1:.*]] = linalg.tensor_collapse_shape %[[ARG0]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[RESHAPE2:.*]] = linalg.tensor_expand_shape %[[RESHAPE1]] {{\[}}[0, 1]]
%0 = "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<3x2xf32>) -> tensor<2x3xf32>
// CHECK-NEXT: return %[[RESHAPE2]]
return %0 : tensor<2x3xf32>
@ -499,7 +499,7 @@ func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
// CHECK-LABEL: @test_reshape_downrank_6D
func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
// CHECK: linalg.tensor_reshape %arg0 {{\[}}[0, 1, 2], [3], [4, 5]]
// CHECK: linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1, 2], [3], [4, 5]]
%0 = "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32>
return %0 : tensor<6x5x77xf32>
}
@ -549,7 +549,7 @@ func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
// CHECK: ^bb0(%arg1: f32, %arg2: f32)
// CHECK: [[RES:%.+]] = addf %arg1, %arg2 : f32
// CHECK: linalg.yield [[RES]] : f32
// CHECK: linalg.tensor_reshape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xf32> into tensor<1x4xf32>
// CHECK: linalg.tensor_expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xf32> into tensor<1x4xf32>
%0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<1x4xf32>
// CHECK: [[INIT:%.+]] = linalg.init_tensor [5]
@ -559,7 +559,7 @@ func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
// CHECK: ^bb0(%arg1: f32, %arg2: f32)
// CHECK: [[RES:%.+]] = addf %arg1, %arg2 : f32
// CHECK: linalg.yield [[RES]] : f32
// CHECK: linalg.tensor_reshape [[GENERIC]] {{\[}}[0, 1]] : tensor<5xf32> into tensor<5x1xf32>
// CHECK: linalg.tensor_expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<5xf32> into tensor<5x1xf32>
%1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xf32>) -> tensor<5x1xf32>
// CHECK: constant 1.0
@ -600,7 +600,7 @@ func @reduce_int(%arg0: tensor<5x4xi32>) -> () {
// CHECK: ^bb0(%arg1: i32, %arg2: i32)
// CHECK: [[RES:%.+]] = addi %arg1, %arg2 : i32
// CHECK: linalg.yield [[RES]] : i32
// CHECK: linalg.tensor_reshape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xi32> into tensor<1x4xi32>
// CHECK: linalg.tensor_expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xi32> into tensor<1x4xi32>
%0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<1x4xi32>
// CHECK: [[INIT:%.+]] = linalg.init_tensor [5]
@ -610,7 +610,7 @@ func @reduce_int(%arg0: tensor<5x4xi32>) -> () {
// CHECK: ^bb0(%arg1: i32, %arg2: i32)
// CHECK: [[RES:%.+]] = addi %arg1, %arg2 : i32
// CHECK: linalg.yield [[RES]] : i32
// CHECK: linalg.tensor_reshape [[GENERIC]] {{\[}}[0, 1]] : tensor<5xi32> into tensor<5x1xi32>
// CHECK: linalg.tensor_expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<5xi32> into tensor<5x1xi32>
%1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xi32>) -> tensor<5x1xi32>
// CHECK: constant 1
@ -650,7 +650,7 @@ func @reduce_bool(%arg0: tensor<5x4xi1>) -> () {
// CHECK: ^bb0(%arg1: i1, %arg2: i1)
// CHECK: [[RES:%.+]] = and %arg1, %arg2 : i1
// CHECK: linalg.yield [[RES]] : i1
// CHECK: linalg.tensor_reshape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xi1> into tensor<1x4xi1>
// CHECK: linalg.tensor_expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xi1> into tensor<1x4xi1>
%0 = "tosa.reduce_all"(%arg0) {axis = 0 : i64} : (tensor<5x4xi1>) -> tensor<1x4xi1>
// CHECK: constant false
@ -822,19 +822,19 @@ func @tile(%arg0 : tensor<2x3xi8>) -> () {
// CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 2, 1, 3]
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3xi8>) outs([[INIT]] : tensor<2x2x1x3xi8>)
// CHECK: linalg.yield %arg1 : i8
// CHECK: linalg.tensor_reshape [[GENERIC]] {{\[}}[0, 1, 2], [3]]
// CHECK: linalg.tensor_collapse_shape [[GENERIC]] {{\[}}[0, 1, 2], [3]]
%0 = "tosa.tile"(%arg0) {multiples = [2, 1]} : (tensor<2x3xi8>) -> (tensor<4x3xi8>)
// CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 2, 2, 3]
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3xi8>) outs([[INIT]] : tensor<1x2x2x3xi8>)
// CHECK: linalg.yield %arg1 : i8
// CHECK: linalg.tensor_reshape [[GENERIC]] {{\[}}[0, 1], [2, 3]]
// CHECK: linalg.tensor_collapse_shape [[GENERIC]] {{\[}}[0, 1], [2, 3]]
%1 = "tosa.tile"(%arg0) {multiples = [1, 2]} : (tensor<2x3xi8>) -> (tensor<2x6xi8>)
// CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 2, 7, 3]
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3xi8>) outs([[INIT]] : tensor<5x2x7x3xi8>)
// CHECK: linalg.yield %arg1 : i8
// CHECK: linalg.tensor_reshape [[GENERIC]] {{\[}}[0, 1], [2, 3]]
// CHECK: linalg.tensor_collapse_shape [[GENERIC]] {{\[}}[0, 1], [2, 3]]
%2 = "tosa.tile"(%arg0) {multiples = [5, 7]} : (tensor<2x3xi8>) -> (tensor<10x21xi8>)
return
@ -1097,7 +1097,7 @@ func @max_pool_i32(%arg0: tensor<1x6x34x62xi32>) -> () {
func @avg_pool(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>) {
// Initial piece computes the sum of the pooling region, with appropriate padding.
// CHECK: [[CONST:%.+]] = constant 0
// CHECK: [[PAD:%.+]] = linalg.pad_tensor %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
// CHECK: [[PAD:%.+]] = linalg.pad_tensor %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
// CHECK: [[CONST:%.+]] = constant 0
// CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 33, 62]
// CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CONST]])
@ -1188,9 +1188,9 @@ func @depthwise_conv(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf32>,
// CHECK: ^bb0(%arg3: f32, %arg4: f32): // no predecessors
// CHECK: linalg.yield %arg3 : f32
// CHECK: } -> tensor<1x5x5x33xf32>
// CHECK: [[DBIAS:%.+]] = linalg.tensor_reshape [[BIAS]] {{\[}}[0], [1], [2], [3, 4]]
// CHECK: [[DBIAS:%.+]] = linalg.tensor_expand_shape [[BIAS]] {{\[}}[0], [1], [2], [3, 4]]
// CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_input_nhwc_filter_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>) outs([[DBIAS]] : tensor<1x5x5x3x11xf32>)
// CHECK: linalg.tensor_reshape %3 {{\[}}[0], [1], [2], [3, 4]]
// CHECK: linalg.tensor_collapse_shape %3 {{\[}}[0], [1], [2], [3, 4]]
%2 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) { pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1] } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>) -> (tensor<1x5x5x33xf32>)
return
}
@ -1202,10 +1202,10 @@ func @depthwise_conv(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf32>,
func @resize_nearest(%input: tensor<1x2x2x1xf32>) -> () {
// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 4, 4, 1]
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK: %[[IDX0:.+]] = linalg.index 0
// CHECK: %[[IDX1:.+]] = linalg.index 1
// CHECK: %[[IDX2:.+]] = linalg.index 2
// CHECK: %[[IDX3:.+]] = linalg.index 3
// CHECK: %[[IDX0:.+]] = linalg.index 0
// CHECK: %[[IDX1:.+]] = linalg.index 1
// CHECK: %[[IDX2:.+]] = linalg.index 2
// CHECK: %[[IDX3:.+]] = linalg.index 3
// CHECK-DAG: %[[XYMIN:.+]] = constant 0
// CHECK-DAG: %[[YMAX:.+]] = constant 1
// CHECK-DAG: %[[XMAX:.+]] = constant 1
@ -1271,9 +1271,9 @@ func @resize_nearest(%input: tensor<1x2x2x1xf32>) -> () {
func @resize_bilinear(%input: tensor<1x2x2x1xf32>) -> () {
// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 4, 4, 1]
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK: %[[IDX0:.+]] = linalg.index 0
// CHECK: %[[IDX1:.+]] = linalg.index 1
// CHECK: %[[IDX2:.+]] = linalg.index 2
// CHECK: %[[IDX0:.+]] = linalg.index 0
// CHECK: %[[IDX1:.+]] = linalg.index 1
// CHECK: %[[IDX2:.+]] = linalg.index 2
// CHECK: %[[IDX3:.+]] = linalg.index 3
// CHECK: %[[XYMIN:.+]] = constant 0
// CHECK: %[[YMAX:.+]] = constant 1
@ -1353,9 +1353,9 @@ func @resize_bilinear(%input: tensor<1x2x2x1xf32>) -> () {
func @resize_nearest_int(%input: tensor<1x2x2x1xi32>) -> () {
// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 4, 4, 1]
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK: %[[IDX0:.+]] = linalg.index 0
// CHECK: %[[IDX1:.+]] = linalg.index 1
// CHECK: %[[IDX2:.+]] = linalg.index 2
// CHECK: %[[IDX0:.+]] = linalg.index 0
// CHECK: %[[IDX1:.+]] = linalg.index 1
// CHECK: %[[IDX2:.+]] = linalg.index 2
// CHECK: %[[IDX3:.+]] = linalg.index 3
// CHECK-DAG: %[[XYMIN:.+]] = constant 0
// CHECK-DAG: %[[YMAX:.+]] = constant 1
@ -1422,7 +1422,7 @@ func @resize_bilinear_int(%input: tensor<1x2x2x1xi8>) -> () {
// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 4, 4, 1]
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK: %[[IDX0:.+]] = linalg.index 0
// CHECK: %[[IDX0:.+]] = linalg.index 0
// CHECK: %[[IDX3:.+]] = linalg.index 3
// CHECK: %[[XYMIN:.+]] = constant 0

View File

@ -253,15 +253,15 @@ func @bufferize_fill(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// -----
// CHECK-LABEL: func @bufferize_tensor_reshape(
// CHECK-LABEL: func @bufferize_tensor_collapse_shape(
// CHECK-SAME: %[[IN:.*]]: tensor<4x5xf32>
func @bufferize_tensor_reshape(%arg0: tensor<4x5xf32>) -> tensor<20xf32> {
%out = linalg.tensor_reshape %arg0 [[0, 1]] :
func @bufferize_tensor_collapse_shape(%arg0: tensor<4x5xf32>) -> tensor<20xf32> {
%out = linalg.tensor_collapse_shape %arg0 [[0, 1]] :
tensor<4x5xf32> into tensor<20xf32>
return %out : tensor<20xf32>
}
// CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[IN]] : memref<4x5xf32>
// CHECK: %[[RESHAPE:.*]] = linalg.reshape %[[MEMREF]] {{\[}}[0, 1]]
// CHECK: %[[RESHAPE:.*]] = linalg.collapse_shape %[[MEMREF]] {{\[}}[0, 1]]
// CHECK-SAME: : memref<4x5xf32> into memref<20xf32>
// CHECK: %[[TENSOR:.*]] = memref.tensor_load %[[RESHAPE]] : memref<20xf32>
// CHECK: return %[[TENSOR]]

View File

@ -46,9 +46,9 @@ func @memref_cast_into_tiled_loop(%arg0: memref<192xf32>) {
// CHECK-LABEL: zero_rank_reshape_multi
func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK: return %arg0
%0 = linalg.tensor_reshape %arg0 [] : tensor<f32> into tensor<1xf32>
%1 = linalg.tensor_reshape %0 [[0, 1]] : tensor<1xf32> into tensor<1x1xf32>
%2 = linalg.tensor_reshape %1 [] : tensor<1x1xf32> into tensor<f32>
%0 = linalg.tensor_expand_shape %arg0 [] : tensor<f32> into tensor<1xf32>
%1 = linalg.tensor_expand_shape %0 [[0, 1]] : tensor<1xf32> into tensor<1x1xf32>
%2 = linalg.tensor_collapse_shape %1 [] : tensor<1x1xf32> into tensor<f32>
return %2 : tensor<f32>
}
@ -56,175 +56,175 @@ func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
func @collapsing_tensor_reshapes(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x?xf32>
{
%0 = linalg.tensor_reshape %arg0 [[0, 1], [2], [3, 4]]
%0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2], [3, 4]]
: tensor<?x?x?x?x?xf32> into tensor<?x?x?xf32>
%1 = linalg.tensor_reshape %0 [[0, 1], [2]]
%1 = linalg.tensor_collapse_shape %0 [[0, 1], [2]]
: tensor<?x?x?xf32> into tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: collapsing_tensor_reshapes
// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
// CHECK-NOT: linalg.tensor_reshape
// CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
// CHECK-NOT: linalg.tensor_collapse_shape
// -----
func @collapsing_tensor_reshapes_to_zero_dim(%arg0 : tensor<1x1x1xf32>)
-> tensor<f32> {
%0 = linalg.tensor_reshape %arg0 [[0, 1, 2]]
%0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2]]
: tensor<1x1x1xf32> into tensor<1xf32>
%1 = linalg.tensor_reshape %0 [] : tensor<1xf32> into tensor<f32>
%1 = linalg.tensor_collapse_shape %0 [] : tensor<1xf32> into tensor<f32>
return %1 : tensor<f32>
}
// CHECK-LABEL: collapsing_tensor_reshapes_to_zero
// CHECK: linalg.tensor_reshape %{{.*}} []
// CHECK: linalg.tensor_collapse_shape %{{.*}} []
// CHECK-SAME: tensor<1x1x1xf32> into tensor<f32>
// -----
func @collapsing_memref_reshapes_to_zero_dim(%arg0 : memref<1x1x1xf32>)
-> memref<f32> {
%0 = linalg.reshape %arg0 [[0, 1, 2]]
%0 = linalg.collapse_shape %arg0 [[0, 1, 2]]
: memref<1x1x1xf32> into memref<1xf32>
%1 = linalg.reshape %0 [] : memref<1xf32> into memref<f32>
%1 = linalg.collapse_shape %0 [] : memref<1xf32> into memref<f32>
return %1 : memref<f32>
}
// CHECK-LABEL: collapsing_memref_reshapes_to_zero
// CHECK: linalg.reshape %{{.*}} []
// CHECK: linalg.collapse_shape %{{.*}} []
// CHECK-SAME: memref<1x1x1xf32> into memref<f32>
// -----
func @expanding_tensor_reshapes(%arg0 : tensor<?x?xf32>) -> tensor<?x6x4x?x5xf32>
{
%0 = linalg.tensor_reshape %arg0 [[0, 1], [2]]
%0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2]]
: tensor<?x?xf32> into tensor<?x4x?xf32>
%1 = linalg.tensor_reshape %0 [[0, 1], [2], [3, 4]]
%1 = linalg.tensor_expand_shape %0 [[0, 1], [2], [3, 4]]
: tensor<?x4x?xf32> into tensor<?x6x4x?x5xf32>
return %1 : tensor<?x6x4x?x5xf32>
}
// CHECK-LABEL: expanding_tensor_reshapes
// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
// CHECK-NOT: linalg.tensor_reshape
// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
// CHECK-NOT: linalg.tensor_expand_shape
// -----
func @collapsing_memref_reshapes(%arg0 : memref<?x?x?x?x?xf32>) -> memref<?x?xf32>
{
%0 = linalg.reshape %arg0 [[0, 1], [2], [3, 4]]
%0 = linalg.collapse_shape %arg0 [[0, 1], [2], [3, 4]]
: memref<?x?x?x?x?xf32> into memref<?x?x?xf32>
%1 = linalg.reshape %0 [[0, 1], [2]]
%1 = linalg.collapse_shape %0 [[0, 1], [2]]
: memref<?x?x?xf32> into memref<?x?xf32>
return %1 : memref<?x?xf32>
}
// CHECK-LABEL: collapsing_memref_reshapes
// CHECK: linalg.reshape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
// CHECK-NOT: linalg.reshape
// CHECK: linalg.collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
// CHECK-NOT: linalg.collapse_shape
// -----
func @expanding_memref_reshapes(%arg0 : memref<?x?xf32>) -> memref<?x6x4x5x?xf32>
{
%0 = linalg.reshape %arg0 [[0, 1], [2]]
%0 = linalg.expand_shape %arg0 [[0, 1], [2]]
: memref<?x?xf32> into memref<?x4x?xf32>
%1 = linalg.reshape %0 [[0, 1], [2], [3, 4]]
%1 = linalg.expand_shape %0 [[0, 1], [2], [3, 4]]
: memref<?x4x?xf32> into memref<?x6x4x5x?xf32>
return %1 : memref<?x6x4x5x?xf32>
}
// CHECK-LABEL: expanding_memref_reshapes
// CHECK: linalg.reshape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
// CHECK-NOT: linalg.reshape
// CHECK: linalg.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
// CHECK-NOT: linalg.expand_shape
// -----
func @expanding_tensor_reshapes_to_zero_dim(%arg0 : tensor<f32>)
-> tensor<1x1x1xf32> {
%0 = linalg.tensor_reshape %arg0 [] : tensor<f32> into tensor<1xf32>
%1 = linalg.tensor_reshape %0 [[0, 1, 2]]
%0 = linalg.tensor_expand_shape %arg0 [] : tensor<f32> into tensor<1xf32>
%1 = linalg.tensor_expand_shape %0 [[0, 1, 2]]
: tensor<1xf32> into tensor<1x1x1xf32>
return %1 : tensor<1x1x1xf32>
}
// CHECK-LABEL: expanding_tensor_reshapes_to_zero
// CHECK: linalg.tensor_reshape %{{.*}} []
// CHECK: linalg.tensor_expand_shape %{{.*}} []
// CHECK-SAME: tensor<f32> into tensor<1x1x1xf32>
// -----
func @expanding_memref_reshapes_to_zero_dim(%arg0 : memref<f32>)
-> memref<1x1x1xf32> {
%0 = linalg.reshape %arg0 [] : memref<f32> into memref<1xf32>
%1 = linalg.reshape %0 [[0, 1, 2]]
%0 = linalg.expand_shape %arg0 [] : memref<f32> into memref<1xf32>
%1 = linalg.expand_shape %0 [[0, 1, 2]]
: memref<1xf32> into memref<1x1x1xf32>
return %1 : memref<1x1x1xf32>
}
// CHECK-LABEL: expanding_memref_reshapes_to_zero
// CHECK: linalg.reshape %{{.*}} []
// CHECK: linalg.expand_shape %{{.*}} []
// CHECK-SAME: memref<f32> into memref<1x1x1xf32>
// -----
func @fold_tensor_reshape(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32>
{
%0 = linalg.tensor_reshape %arg0 [[0, 1], [2]]
%0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2]]
: tensor<12x4xf32> into tensor<3x4x4xf32>
%1 = linalg.tensor_reshape %0 [[0, 1], [2]]
%1 = linalg.tensor_collapse_shape %0 [[0, 1], [2]]
: tensor<3x4x4xf32> into tensor<12x4xf32>
return %1 : tensor<12x4xf32>
}
// CHECK-LABEL: @fold_tensor_reshape
// CHECK-NOT: linalg.tensor_reshape
  // CHECK-NOT: linalg.{{.*}}_shape
// -----
func @fold_tensor_reshape_dynamic(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
{
%0 = linalg.tensor_reshape %arg0 [[0, 1], [2]]
%0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2]]
: tensor<?x?xf32> into tensor<?x4x?xf32>
%1 = linalg.tensor_reshape %0 [[0, 1], [2]]
%1 = linalg.tensor_collapse_shape %0 [[0, 1], [2]]
: tensor<?x4x?xf32> into tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: @fold_tensor_reshape_dynamic
// CHECK-NOT: linalg.tensor_reshape
// CHECK-NOT: linalg.{{.*}}_shape
// -----
func @fold_memref_reshape(%arg0 : memref<12x4xf32>) -> memref<12x4xf32>
{
%0 = linalg.reshape %arg0 [[0, 1], [2]]
%0 = linalg.expand_shape %arg0 [[0, 1], [2]]
: memref<12x4xf32> into memref<3x4x4xf32>
%1 = linalg.reshape %0 [[0, 1], [2]]
%1 = linalg.collapse_shape %0 [[0, 1], [2]]
: memref<3x4x4xf32> into memref<12x4xf32>
return %1 : memref<12x4xf32>
}
// CHECK-LABEL: @fold_memref_reshape
// CHECK-NOT: linalg.reshape
// CHECK-NOT: linalg.{{.*}}_shape
// -----
func @fold_memref_reshape_dynamic(%arg0 : memref<?x?xf32>) -> memref<?x?xf32>
{
%0 = linalg.reshape %arg0 [[0, 1], [2]]
%0 = linalg.expand_shape %arg0 [[0, 1], [2]]
: memref<?x?xf32> into memref<?x4x?xf32>
%1 = linalg.reshape %0 [[0, 1], [2]]
%1 = linalg.collapse_shape %0 [[0, 1], [2]]
: memref<?x4x?xf32> into memref<?x?xf32>
return %1 : memref<?x?xf32>
}
// CHECK-LABEL: @fold_memref_reshape_dynamic
// CHECK-NOT: linalg.reshape
// CHECK-NOT: linalg.{{.*}}_shape
// -----
func @reshape_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>) -> tensor<24x5x42x8xf32>
{
%0 = linalg.tensor_reshape %arg0 [[0, 1, 2, 3, 4, 5, 6]]
%0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2, 3, 4, 5, 6]]
: tensor<2x3x4x5x6x7x8xf32> into tensor<40320xf32>
%1 = linalg.tensor_reshape %0 [[0, 1, 2, 3]]
%1 = linalg.tensor_expand_shape %0 [[0, 1, 2, 3]]
: tensor<40320xf32> into tensor<24x5x42x8xf32>
return %1 : tensor<24x5x42x8xf32>
}
// CHECK: func @reshape_collapse
// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8xf32>
// CHECK: %[[RESULT:.+]] = linalg.tensor_reshape %[[ARG0]]
// CHECK: %[[RESULT:.+]] = linalg.tensor_collapse_shape %[[ARG0]]
// CHECK-SAME: [0, 1, 2], [3], [4, 5], [6]
// CHECK: return %[[RESULT]]
@ -232,15 +232,15 @@ func @reshape_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>) -> tensor<24x5x42x8xf3
func @reshape_expand(%arg0 : tensor<24x5x42x8xf32>) -> tensor<2x3x4x5x6x7x8xf32>
{
%0 = linalg.tensor_reshape %arg0 [[0, 1, 2, 3]]
%0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2, 3]]
: tensor<24x5x42x8xf32> into tensor<40320xf32>
%1 = linalg.tensor_reshape %0 [[0, 1, 2, 3, 4, 5, 6]]
%1 = linalg.tensor_expand_shape %0 [[0, 1, 2, 3, 4, 5, 6]]
: tensor<40320xf32> into tensor<2x3x4x5x6x7x8xf32>
return %1 : tensor<2x3x4x5x6x7x8xf32>
}
// CHECK: func @reshape_expand
// CHECK-SAME: %[[ARG0:.+]]: tensor<24x5x42x8xf32>
// CHECK: %[[RESULT:.+]] = linalg.tensor_reshape %[[ARG0]]
// CHECK: %[[RESULT:.+]] = linalg.tensor_expand_shape %[[ARG0]]
// CHECK-SAME: [0, 1, 2], [3], [4, 5], [6]
// CHECK: return %[[RESULT]]
@ -248,84 +248,84 @@ func @reshape_expand(%arg0 : tensor<24x5x42x8xf32>) -> tensor<2x3x4x5x6x7x8xf32>
func @expand_reshape_1D(%arg0 : tensor<2048xf32>) -> tensor<4x512xf32>
{
%0 = linalg.tensor_reshape %arg0 [[0, 1, 2, 3]]
%0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2, 3]]
: tensor<2048xf32> into tensor<1x4x1x512xf32>
%1 = linalg.tensor_reshape %0 [[0, 1, 2], [3]]
%1 = linalg.tensor_collapse_shape %0 [[0, 1, 2], [3]]
: tensor<1x4x1x512xf32> into tensor<4x512xf32>
return %1 : tensor<4x512xf32>
}
// CHECK: func @expand_reshape_1D
// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1]]
// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1]]
// CHECK-SAME: tensor<2048xf32> into tensor<4x512xf32>
// -----
func @fold_reshape_1D(%arg0 : tensor<4x512xf32>) -> tensor<2048xf32>
{
%0 = linalg.tensor_reshape %arg0 [[0, 1, 2], [3]]
%0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2], [3]]
: tensor<4x512xf32> into tensor<1x4x1x512xf32>
%1 = linalg.tensor_reshape %0 [[0, 1, 2, 3]]
%1 = linalg.tensor_collapse_shape %0 [[0, 1, 2, 3]]
: tensor<1x4x1x512xf32> into tensor<2048xf32>
return %1 : tensor<2048xf32>
}
// CHECK: func @fold_reshape_1D
// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1]]
// CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0, 1]]
// CHECK-SAME: tensor<4x512xf32> into tensor<2048xf32>
// -----
func @fold_reshape_unit_dims(%arg0 : tensor<2048x1x1xf32>) -> tensor<4x512x1x1xf32>
{
%0 = linalg.tensor_reshape %arg0 [[0, 1, 2, 3], [4], [5]]
%0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2, 3], [4], [5]]
: tensor<2048x1x1xf32> into tensor<1x4x1x512x1x1xf32>
%1 = linalg.tensor_reshape %0 [[0, 1, 2], [3], [4], [5]]
%1 = linalg.tensor_collapse_shape %0 [[0, 1, 2], [3], [4], [5]]
: tensor<1x4x1x512x1x1xf32> into tensor<4x512x1x1xf32>
return %1 : tensor<4x512x1x1xf32>
}
// CHECK: func @fold_reshape_unit_dims
// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1], [2], [3]]
// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1], [2], [3]]
// CHECK-SAME: tensor<2048x1x1xf32> into tensor<4x512x1x1xf32>
// -----
func @expand_reshape_unit_dims(%arg0 : tensor<2048x1x2048xf32>) -> tensor<4x512x1x512x4xf32>
{
%0 = linalg.tensor_reshape %arg0 [[0, 1, 2, 3, 4], [5], [6, 7, 8]]
%0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2, 3, 4], [5], [6, 7, 8]]
: tensor<2048x1x2048xf32> into tensor<1x4x1x512x1x1x512x1x4xf32>
%1 = linalg.tensor_reshape %0 [[0, 1, 2], [3, 4], [5], [6, 7], [8]]
%1 = linalg.tensor_collapse_shape %0 [[0, 1, 2], [3, 4], [5], [6, 7], [8]]
: tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32>
return %1 : tensor<4x512x1x512x4xf32>
}
// CHECK: func @expand_reshape_unit_dims
// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1], [2], [3, 4]]
// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1], [2], [3, 4]]
// CHECK-SAME: tensor<2048x1x2048xf32> into tensor<4x512x1x512x4xf32>
// -----
func @fold_reshape_trailing_unit_dims(%arg0: tensor<2xf32>) -> tensor<2x1xf32>
{
%0 = linalg.tensor_reshape %arg0 [[0, 1, 2]]
%0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2]]
: tensor<2xf32> into tensor<2x1x1xf32>
%1 = linalg.tensor_reshape %0 [[0], [1, 2]]
%1 = linalg.tensor_collapse_shape %0 [[0], [1, 2]]
: tensor<2x1x1xf32> into tensor<2x1xf32>
return %1 : tensor<2x1xf32>
}
// CHECK: func @fold_reshape_trailing_unit_dims
// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1]]
// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1]]
// CHECK-SAME: tensor<2xf32> into tensor<2x1xf32>
// -----
func @collapse_reshape_unit_dims_dynamic(%arg0 : tensor<?x1x?x1x1x?x?x1x1xf32>) -> tensor<?x?x?x?xf32>
{
%0 = linalg.tensor_reshape %arg0 [[0], [1, 2], [3], [4], [5], [6, 7, 8]]
%0 = linalg.tensor_collapse_shape %arg0 [[0], [1, 2], [3], [4], [5], [6, 7, 8]]
: tensor<?x1x?x1x1x?x?x1x1xf32> into tensor<?x?x1x1x?x?xf32>
%1 = linalg.tensor_reshape %0 [[0], [1], [2, 3, 4], [5]]
%1 = linalg.tensor_collapse_shape %0 [[0], [1], [2, 3, 4], [5]]
: tensor<?x?x1x1x?x?xf32> into tensor<?x?x?x?xf32>
return %1 : tensor<?x?x?x?xf32>
}
// CHECK: func @collapse_reshape_unit_dims_dynamic
// CHECK: linalg.tensor_reshape
// CHECK: linalg.tensor_collapse_shape
// CHECK-SAME: [0], [1, 2], [3, 4, 5], [6, 7, 8]
// CHECK-SAME: tensor<?x1x?x1x1x?x?x1x1xf32> into tensor<?x?x?x?xf32>
@ -333,72 +333,72 @@ func @collapse_reshape_unit_dims_dynamic(%arg0 : tensor<?x1x?x1x1x?x?x1x1xf32>)
func @fold_reshape_trailing_unit_dims(%arg0: tensor<2xf32>) -> tensor<2x1xf32>
{
%0 = linalg.tensor_reshape %arg0 [[0, 1, 2]]
%0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2]]
: tensor<2xf32> into tensor<2x1x1xf32>
%1 = linalg.tensor_reshape %0 [[0], [1, 2]]
%1 = linalg.tensor_collapse_shape %0 [[0], [1, 2]]
: tensor<2x1x1xf32> into tensor<2x1xf32>
return %1 : tensor<2x1xf32>
}
// CHECK: func @fold_reshape_trailing_unit_dims
// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1]]
// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1]]
// CHECK-SAME: tensor<2xf32> into tensor<2x1xf32>
// -----
func @fold_reshape_trailing_unit_dims_dynamic(%arg0: tensor<1x1x?x1x1x1xf32>) -> tensor<?xf32>
{
%0 = linalg.tensor_reshape %arg0 [[0, 1, 2], [3], [4], [5]]
%0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2], [3], [4], [5]]
: tensor<1x1x?x1x1x1xf32> into tensor<?x1x1x1xf32>
%1 = linalg.tensor_reshape %0 [[0, 1, 2, 3]]
%1 = linalg.tensor_collapse_shape %0 [[0, 1, 2, 3]]
: tensor<?x1x1x1xf32> into tensor<?xf32>
return %1 : tensor<?xf32>
}
// CHECK: func @fold_reshape_trailing_unit_dims_dynamic
// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1, 2, 3, 4, 5]]
// CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4, 5]]
// CHECK-SAME: tensor<1x1x?x1x1x1xf32> into tensor<?xf32>
// -----
func @no_fold_reshapes(%arg0 : tensor<?x?x?xf32>) -> tensor<?x?xf32>
{
%0 = linalg.tensor_reshape %arg0 [[0], [1], [2, 3]]
%0 = linalg.tensor_expand_shape %arg0 [[0], [1], [2, 3]]
: tensor<?x?x?xf32> into tensor<?x?x1x?xf32>
%1 = linalg.tensor_reshape %0 [[0], [1, 2, 3]]
%1 = linalg.tensor_collapse_shape %0 [[0], [1, 2, 3]]
: tensor<?x?x1x?xf32> into tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: func @no_fold_reshapes
// CHECK: linalg.tensor_reshape
// CHECK: linalg.tensor_reshape
// CHECK: linalg.tensor_expand_shape
// CHECK: linalg.tensor_collapse_shape
// -----
func @no_fold_reshape_incompatible(%arg0 : tensor<4x6x8xf32>) -> tensor<2x6x16xf32>
{
%0 = linalg.tensor_reshape %arg0 [[0, 1], [2, 3], [4]]
%0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2, 3], [4]]
: tensor<4x6x8xf32> into tensor<2x2x3x2x8xf32>
%1 = linalg.tensor_reshape %0 [[0], [1, 2], [3, 4]]
%1 = linalg.tensor_collapse_shape %0 [[0], [1, 2], [3, 4]]
: tensor<2x2x3x2x8xf32> into tensor<2x6x16xf32>
return %1 : tensor<2x6x16xf32>
}
// CHECK-LABEL: func @no_fold_reshape_incompatible
// CHECK: linalg.tensor_reshape
// CHECK: linalg.tensor_reshape
// CHECK: linalg.tensor_expand_shape
// CHECK: linalg.tensor_collapse_shape
// -----
func @no_fold_reshape_empty_expr(%arg0: tensor<3x2x2xf32>) -> tensor<12x1xf32> {
%0 = linalg.tensor_reshape %arg0 [[0], [1], [2, 3]]
%0 = linalg.tensor_expand_shape %arg0 [[0], [1], [2, 3]]
: tensor<3x2x2xf32> into tensor<3x2x2x1xf32>
%1 = linalg.tensor_reshape %0 [[0, 1, 2], [3]]
%1 = linalg.tensor_collapse_shape %0 [[0, 1, 2], [3]]
: tensor<3x2x2x1xf32> into tensor<12x1xf32>
return %1 : tensor<12x1xf32>
}
// CHECK: func @no_fold_reshape_empty_expr
// CHECK-SAME: %[[ARG0:.+]]: tensor<3x2x2xf32>
// CHECK: %[[RARG0:.+]] = linalg.tensor_reshape %[[ARG0]]
// CHECK: %[[RARG0:.+]] = linalg.tensor_expand_shape %[[ARG0]]
// CHECK-SAME: [0], [1], [2, 3]
// CHECK: %[[RES:.+]] = linalg.tensor_reshape %[[RARG0]]
// CHECK: %[[RES:.+]] = linalg.tensor_collapse_shape %[[RARG0]]
// CHECK-SAME: [0, 1, 2], [3]
// CHECK: return %[[RES:.+]] : tensor<12x1xf32>
@ -436,49 +436,49 @@ func @dce_zero_memref(%arg0 : memref<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf
func @reshape_splat_constant_int32() -> tensor<2x4x2xi32>
{
%c0 = constant dense<42> : tensor<2x8xi32>
%0 = linalg.tensor_reshape %c0 [[0], [1, 2]]
%0 = linalg.tensor_expand_shape %c0 [[0], [1, 2]]
: tensor<2x8xi32> into tensor<2x4x2xi32>
return %0 : tensor<2x4x2xi32>
}
// CHECK-LABEL: @reshape_splat_constant_int32
// CHECK: %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xi32>
// CHECK-NOT: linalg.tensor_reshape
// CHECK-NOT: linalg.tensor_expand_shape
// CHECK: return %[[CST]]
func @reshape_splat_constant_int16() -> tensor<2x4x2xi16>
{
%c0 = constant dense<42> : tensor<2x8xi16>
%0 = linalg.tensor_reshape %c0 [[0], [1, 2]]
%0 = linalg.tensor_expand_shape %c0 [[0], [1, 2]]
: tensor<2x8xi16> into tensor<2x4x2xi16>
return %0 : tensor<2x4x2xi16>
}
// CHECK-LABEL: @reshape_splat_constant_int16
// CHECK: %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xi16>
// CHECK-NOT: linalg.tensor_reshape
// CHECK-NOT: linalg.tensor_expand_shape
// CHECK: return %[[CST]]
func @reshape_splat_constant_float32() -> tensor<2x4x2xf32>
{
%c0 = constant dense<42.0> : tensor<2x8xf32>
%0 = linalg.tensor_reshape %c0 [[0], [1, 2]]
%0 = linalg.tensor_expand_shape %c0 [[0], [1, 2]]
: tensor<2x8xf32> into tensor<2x4x2xf32>
return %0 : tensor<2x4x2xf32>
}
// CHECK-LABEL: @reshape_splat_constant_float32
// CHECK: %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xf32>
// CHECK-NOT: linalg.tensor_reshape
// CHECK-NOT: linalg.tensor_expand_shape
// CHECK: return %[[CST]]
func @reshape_splat_constant_float64() -> tensor<2x4x2xf64>
{
%c0 = constant dense<42.0> : tensor<2x8xf64>
%0 = linalg.tensor_reshape %c0 [[0], [1, 2]]
%0 = linalg.tensor_expand_shape %c0 [[0], [1, 2]]
: tensor<2x8xf64> into tensor<2x4x2xf64>
return %0 : tensor<2x4x2xf64>
}
// CHECK-LABEL: @reshape_splat_constant_float64
// CHECK: %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xf64>
// CHECK-NOT: linalg.tensor_reshape
// CHECK-NOT: linalg.tensor_expand_shape
// CHECK: return %[[CST]]
// -----
@ -733,7 +733,7 @@ func @init_tensor_dim_of_linalg_result(%arg_0 : tensor<?xf32>,
func @init_tensor_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
%0 = linalg.init_tensor [6, 5, %arg0] : tensor<6x5x?xf32>
%1 = linalg.tensor_reshape %0 [[0, 1], [2], [3, 4, 5]]
%1 = linalg.tensor_expand_shape %0 [[0, 1], [2], [3, 4, 5]]
: tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
return %1 : tensor<2x3x5x4x?x7xf32>
}
@ -748,7 +748,7 @@ func @init_tensor_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
func @init_tensor_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
%0 = linalg.init_tensor [2, 3, 5, 4, %arg0, 7] : tensor<2x3x5x4x?x7xf32>
%1 = linalg.tensor_reshape %0 [[0, 1], [2], [3, 4, 5]]
%1 = linalg.tensor_collapse_shape %0 [[0, 1], [2], [3, 4, 5]]
: tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
return %1 : tensor<6x5x?xf32>
}
@ -898,7 +898,7 @@ func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>) -> (index, index, index)
%c1 = constant 1 : index
%c3 = constant 3 : index
%c4 = constant 4 : index
%0 = linalg.tensor_reshape %arg0 [[0, 1], [2], [3, 4, 5]]
%0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2], [3, 4, 5]]
: tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
%1 = memref.dim %0, %c1 : tensor<2x3x5x4x?x7xf32>
%2 = memref.dim %0, %c3 : tensor<2x3x5x4x?x7xf32>
@ -921,7 +921,7 @@ func @dim_reshape_collapse(%arg0 : tensor<2x3x5x4x?x7xf32>) -> (index, index)
{
%c1 = constant 1 : index
%c2 = constant 2 : index
%0 = linalg.tensor_reshape %arg0 [[0, 1], [2], [3, 4, 5]]
%0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2], [3, 4, 5]]
: tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
%1 = memref.dim %0, %c1 : tensor<6x5x?xf32>
%2 = memref.dim %0, %c2 : tensor<6x5x?xf32>
@ -979,7 +979,7 @@ func @fold_fill_reshape() -> tensor<6x4xf32> {
%init = linalg.init_tensor [1, 2, 3, 4] : tensor<1x2x3x4xf32>
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %cst) : tensor<6x4xf32>, f32 -> tensor<6x4xf32>
%fill = linalg.fill(%init, %zero) : tensor<1x2x3x4xf32>, f32 -> tensor<1x2x3x4xf32>
%reshape = linalg.tensor_reshape %fill [[0, 1, 2], [3]]
%reshape = linalg.tensor_collapse_shape %fill [[0, 1, 2], [3]]
: tensor<1x2x3x4xf32> into tensor<6x4xf32>
// CHECK: return %[[FILL]] : tensor<6x4xf32>
return %reshape : tensor<6x4xf32>
@ -991,10 +991,10 @@ func @fold_fill_reshape() -> tensor<6x4xf32> {
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?x?xf32>
func @fold_fill_reshape_dynamic(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x?xf32> {
%zero = constant 0.0 : f32
// CHECK: %[[RESHAPE:.+]] = linalg.tensor_reshape %[[ARG0]]
// CHECK: %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[ARG0]]
%0 = linalg.fill(%arg0, %zero) : tensor<?x?x?x?x?xf32>, f32 -> tensor<?x?x?x?x?xf32>
// CHECK: %[[RESULT:.+]] = linalg.fill(%[[RESHAPE]], %{{.+}})
%1 = linalg.tensor_reshape %0 [[0, 1, 2], [3, 4]]
%1 = linalg.tensor_collapse_shape %0 [[0, 1, 2], [3, 4]]
: tensor<?x?x?x?x?xf32> into tensor<?x?xf32>
// CHECK: return %[[RESULT]]
return %1 : tensor<?x?xf32>

View File

@ -19,7 +19,7 @@ func @detensor_simple(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> att
// CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]]
// CHECK: %[[detensored_res:.*]] = addf %[[arg1_val]], %[[arg2_val]]
// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res]]
// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]]
// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_collapse_shape %[[new_tensor_res]]
// CHECK: return %[[reshaped_tensor_res]]
func @detensor_op_sequence(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
@ -61,7 +61,7 @@ func @detensor_op_sequence(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32
// CHECK: %[[detensored_res2:.*]] = mulf %[[arg1_val2]], %[[detensored_res]]
// CHECK: %[[detensored_res3:.*]] = divf %[[detensored_res]], %[[detensored_res2]]
// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res3]]
// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]]
// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_collapse_shape %[[new_tensor_res]]
// CHECK: return %[[reshaped_tensor_res]]
func @detensor_multiple_ops(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
@ -83,7 +83,7 @@ func @detensor_multiple_ops(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f3
// CHECK: %[[detensored_res:.*]] = addf %[[arg1_val]], %[[arg2_val]]
// CHECK: %[[detensored_res2:.*]] = mulf %[[detensored_res]], %[[arg2_val]]
// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res2]]
// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]]
// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_collapse_shape %[[new_tensor_res]]
// CHECK: return %[[reshaped_tensor_res]]
func @detensor_foreign_op(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
@ -103,5 +103,5 @@ func @detensor_foreign_op(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32>
// CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]]
// CHECK: %[[detensored_res:.*]] = "foreign.do_something"(%[[arg1_val]], %[[arg2_val]])
// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res]]
// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]]
// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_collapse_shape %[[new_tensor_res]]
// CHECK: return %[[reshaped_tensor_res]]

View File

@ -10,10 +10,10 @@
func @main() -> (tensor<i32>) attributes {} {
%c0 = constant 0 : i32
%0 = tensor.from_elements %c0 : tensor<1xi32>
%reshaped0 = linalg.tensor_reshape %0 [] : tensor<1xi32> into tensor<i32>
%reshaped0 = linalg.tensor_collapse_shape %0 [] : tensor<1xi32> into tensor<i32>
%c10 = constant 10 : i32
%1 = tensor.from_elements %c10 : tensor<1xi32>
%reshaped1 = linalg.tensor_reshape %1 [] : tensor<1xi32> into tensor<i32>
%reshaped1 = linalg.tensor_collapse_shape %1 [] : tensor<1xi32> into tensor<i32>
br ^bb1(%reshaped0 : tensor<i32>)
^bb1(%2: tensor<i32>): // 2 preds: ^bb0, ^bb2
@ -55,7 +55,7 @@ func @main() -> (tensor<i32>) attributes {} {
// CHECK-NEXT: br ^[[bb3:.*]](%{{.*}} : i32)
// CHECK-NEXT: ^[[bb3]](%{{.*}}: i32)
// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<1xi32>
// CHECK-NEXT: linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor<i32>
// CHECK-NEXT: linalg.tensor_collapse_shape %{{.*}} [] : tensor<1xi32> into tensor<i32>
// CHECK-NEXT: return %{{.*}}
// CHECK-NEXT: }
@ -74,10 +74,10 @@ func @main() -> (tensor<i32>) attributes {} {
func @main() -> (tensor<i32>) attributes {} {
%c0 = constant 0 : i32
%0 = tensor.from_elements %c0 : tensor<1xi32>
%reshaped0 = linalg.tensor_reshape %0 [] : tensor<1xi32> into tensor<i32>
%reshaped0 = linalg.tensor_collapse_shape %0 [] : tensor<1xi32> into tensor<i32>
%c10 = constant 10 : i32
%1 = tensor.from_elements %c10 : tensor<1xi32>
%reshaped1 = linalg.tensor_reshape %1 [] : tensor<1xi32> into tensor<i32>
%reshaped1 = linalg.tensor_collapse_shape %1 [] : tensor<1xi32> into tensor<i32>
br ^bb1(%reshaped0 : tensor<i32>)
^bb1(%2: tensor<i32>): // 2 preds: ^bb0, ^bb2
@ -124,7 +124,7 @@ func @main() -> (tensor<i32>) attributes {} {
// CHECK-NEXT: br ^[[bb4:.*]](%{{.*}} : i32)
// CHECK-NEXT: ^[[bb4]](%{{.*}}: i32)
// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<1xi32>
// CHECK-NEXT: linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor<i32>
// CHECK-NEXT: linalg.tensor_collapse_shape %{{.*}} [] : tensor<1xi32> into tensor<i32>
// CHECK-NEXT: return %{{.*}}
// CHECK-NEXT: }
@ -140,10 +140,10 @@ func @main() -> (tensor<i32>) attributes {} {
func @main() -> (tensor<i32>) attributes {} {
%c0 = constant 0 : i32
%0 = tensor.from_elements %c0 : tensor<1xi32>
%reshaped0 = linalg.tensor_reshape %0 [] : tensor<1xi32> into tensor<i32>
%reshaped0 = linalg.tensor_collapse_shape %0 [] : tensor<1xi32> into tensor<i32>
%c10 = constant 10 : i32
%1 = tensor.from_elements %c10 : tensor<1xi32>
%reshaped1 = linalg.tensor_reshape %1 [] : tensor<1xi32> into tensor<i32>
%reshaped1 = linalg.tensor_collapse_shape %1 [] : tensor<1xi32> into tensor<i32>
br ^bb1(%reshaped0 : tensor<i32>)
^bb1(%2: tensor<i32>): // 2 preds: ^bb0, ^bb2
@ -164,7 +164,7 @@ func @main() -> (tensor<i32>) attributes {} {
^bb2(%6: tensor<i32>): // pred: ^bb1
%12 = tensor.from_elements %c10 : tensor<1xi32>
%reshaped12 = linalg.tensor_reshape %12 [] : tensor<1xi32> into tensor<i32>
%reshaped12 = linalg.tensor_collapse_shape %12 [] : tensor<1xi32> into tensor<i32>
%7 = linalg.init_tensor [] : tensor<i32>
%8 = linalg.generic #attrs
ins(%6, %reshaped12 : tensor<i32>, tensor<i32>)
@ -191,6 +191,6 @@ func @main() -> (tensor<i32>) attributes {} {
// CHECK-NEXT: br ^[[bb3:.*]](%{{.*}} : i32)
// CHECK-NEXT: ^[[bb3]](%{{.*}}: i32)
// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<1xi32>
// CHECK-NEXT: linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor<i32>
// CHECK-NEXT: linalg.tensor_collapse_shape %{{.*}} [] : tensor<1xi32> into tensor<i32>
// CHECK-NEXT: return %{{.*}}
// CHECK-NEXT: }

View File

@ -12,7 +12,7 @@
func @main(%farg0 : tensor<i32>) -> (tensor<i1>) attributes {} {
%c10 = constant 10 : i32
%1 = tensor.from_elements %c10 : tensor<1xi32>
%reshaped1 = linalg.tensor_reshape %1 [] : tensor<1xi32> into tensor<i32>
%reshaped1 = linalg.tensor_collapse_shape %1 [] : tensor<1xi32> into tensor<i32>
%3 = linalg.init_tensor [] : tensor<i1>
%4 = linalg.generic #attrs
ins(%farg0, %reshaped1 : tensor<i32>, tensor<i32>)
@ -30,7 +30,7 @@ func @main(%farg0 : tensor<i32>) -> (tensor<i1>) attributes {} {
// DET-ALL-NEXT: tensor.extract %{{.*}}[]
// DET-ALL-NEXT: cmpi slt, %{{.*}}, %{{.*}}
// DET-ALL-NEXT: tensor.from_elements %{{.*}}
// DET-ALL-NEXT: linalg.tensor_reshape %{{.*}}
// DET-ALL-NEXT: linalg.tensor_collapse_shape %{{.*}}
// DET-ALL-NEXT: return %{{.*}} : tensor<i1>
// DET-ALL-NEXT: }

View File

@ -52,7 +52,7 @@ func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attributes {
// DET-ALL: br ^[[bb1]](%{{.*}} : i32)
// DET-ALL: ^[[bb3]](%{{.*}}: i32)
// DET-ALL: tensor.from_elements {{.*}}
// DET-ALL: linalg.tensor_reshape {{.*}}
// DET-ALL: linalg.tensor_collapse_shape {{.*}}
// DET-ALL: return %{{.*}} : tensor<i32>
// Test detensoring only ops involed in control-flow.
@ -69,5 +69,5 @@ func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attributes {
// DET-CF: br ^[[bb1]](%{{.*}} : i32)
// DET-CF: ^[[bb3]](%{{.*}}: i32)
// DET-CF: tensor.from_elements %{{.*}} : tensor<1xi32>
// DET-CF: linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor<i32>
// DET-CF: linalg.tensor_collapse_shape %{{.*}} [] : tensor<1xi32> into tensor<i32>
// DET-CF: return %{{.*}} : tensor<i32>

View File

@ -80,7 +80,7 @@ func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attribute
// DET-ALL: cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
// DET-ALL: ^[[bb2]](%{{.*}}: i32)
// DET-ALL: tensor.from_elements %{{.*}} : tensor<1xi32>
// DET-ALL: linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor<i32>
// DET-ALL: linalg.tensor_collapse_shape %{{.*}} [] : tensor<1xi32> into tensor<i32>
// DET-ALL: linalg.init_tensor [10] : tensor<10xi32>
// DET-ALL: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {
// DET-ALL: ^bb0(%{{.*}}: i32, %{{.*}}: i32):
@ -89,11 +89,11 @@ func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attribute
// DET-ALL: br ^[[bb1]](%{{.*}} : tensor<10xi32>)
// DET-ALL: ^[[bb3]](%{{.*}}: i32)
// DET-ALL: tensor.from_elements %{{.*}} : tensor<1xi32>
// DET-ALL: linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor<i32>
// DET-ALL: linalg.tensor_collapse_shape %{{.*}} [] : tensor<1xi32> into tensor<i32>
// DET-ALL: return %{{.*}} : tensor<i32>
// DET-ALL: }
// Try to detensor pure control-flow. However, that fails since the potential
// Try to detensor pure control-flow. However, that fails since the potential
// detensorable component contains some ops that cannot be detensored.
//
// DET-CF-LABEL: func @main

View File

@ -10,10 +10,10 @@
func @main() -> () attributes {} {
%c0 = constant 0 : i32
%0 = tensor.from_elements %c0 : tensor<1xi32>
%reshaped0 = linalg.tensor_reshape %0 [] : tensor<1xi32> into tensor<i32>
%reshaped0 = linalg.tensor_collapse_shape %0 [] : tensor<1xi32> into tensor<i32>
%c10 = constant 10 : i32
%1 = tensor.from_elements %c10 : tensor<1xi32>
%reshaped1 = linalg.tensor_reshape %1 [] : tensor<1xi32> into tensor<i32>
%reshaped1 = linalg.tensor_collapse_shape %1 [] : tensor<1xi32> into tensor<i32>
br ^bb1(%reshaped0 : tensor<i32>)
^bb1(%2: tensor<i32>): // 2 preds: ^bb0, ^bb2

View File

@ -23,11 +23,11 @@ func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>, %shape: tensor<?x1x?x1x?xf3
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @drop_one_trip_loops
// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1], [2]]
// CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0, 1], [2]]
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP2]], #[[$MAP3]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1], [2, 3], [4]]
// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1], [2, 3], [4]]
// -----
@ -101,7 +101,7 @@ func @drop_all_loops(%arg0 : tensor<1x1xf32>) -> tensor<1x1xf32>
}
// CHECK: #[[$MAP0:.*]] = affine_map<() -> ()>
// CHECK-LABEL: func @drop_all_loops
// CHECK: linalg.tensor_reshape %{{.*}} []
// CHECK: linalg.tensor_collapse_shape %{{.*}} []
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
// CHECK-SAME: iterator_types = []
@ -162,7 +162,7 @@ func @leading_dim_1_canonicalization(%arg0: tensor<1x5xf32>, %shape: tensor<5xf3
// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func @leading_dim_1_canonicalization
// CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1]]
// CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0, 1]]
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP1]]]
// CHECK-SAME: iterator_types = ["parallel"]
@ -183,8 +183,8 @@ func @leading_dim_1_canonicalization(%arg0: tensor<1x5xf32>, %shape: tensor<5xf3
func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>, %shape : tensor<5x5xf32>) -> tensor<5x5xf32>
{
%0 = linalg.tensor_reshape %arg0 [[0, 1]] : tensor<5xf32> into tensor<1x5xf32>
%1 = linalg.tensor_reshape %arg1 [[0, 1]] : tensor<5xf32> into tensor<5x1xf32>
%0 = linalg.tensor_expand_shape %arg0 [[0, 1]] : tensor<5xf32> into tensor<1x5xf32>
%1 = linalg.tensor_expand_shape %arg1 [[0, 1]] : tensor<5xf32> into tensor<5x1xf32>
%2 = linalg.generic #trait
ins(%0, %1 : tensor<1x5xf32>, tensor<5x1xf32>)
outs(%shape : tensor<5x5xf32>) {
@ -198,11 +198,11 @@ func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>, %shape : tens
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @broadcast_test
// CHECK-NOT: linalg.tensor_reshape
// CHECK-NOT: linalg.tensor_{{.*}}shape
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
// CHECK-NOT: linalg.tensor_reshape
// CHECK-NOT: linalg.tensor_{{.*}}shape
// -----
@ -231,7 +231,7 @@ func @broadcast_scalar(%arg0 : tensor<1x1xf32>, %shape : tensor<?x?xf32>) -> ten
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @broadcast_scalar
// CHECK-SAME: %[[ARG0:.*]]: tensor<1x1xf32>
// CHECK: %[[A:.*]] = linalg.tensor_reshape %[[ARG0]] []
// CHECK: %[[A:.*]] = linalg.tensor_collapse_shape %[[ARG0]] []
// CHECK-SAME: tensor<1x1xf32> into tensor<f32>
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
@ -251,7 +251,7 @@ func @fold_unit_dim_tensor_reshape_op(%arg0 : tensor<5xf32>) -> tensor<2x5xf32>
^bb0(%arg1: f32, %arg2: f32): // no predecessors
linalg.yield %arg1 : f32
} -> tensor<1x2x5xf32>
%3 = linalg.tensor_reshape %2 [[0, 1], [2]]
%3 = linalg.tensor_collapse_shape %2 [[0, 1], [2]]
: tensor<1x2x5xf32> into tensor<2x5xf32>
return %3 : tensor<2x5xf32>
}
@ -283,7 +283,7 @@ func @fold_unit_dim_for_init_tensor(%input: tensor<1x1000xf32>) -> tensor<1xf32>
// CHECK: func @fold_unit_dim_for_init_tensor
// CHECK: %[[INPUT_RESHAPE:.+]] = linalg.tensor_reshape %{{.+}} {{\[}}[0, 1]] : tensor<1x1000xf32> into tensor<1000xf32>
// CHECK: %[[INPUT_RESHAPE:.+]] = linalg.tensor_collapse_shape %{{.+}} {{\[}}[0, 1]] : tensor<1x1000xf32> into tensor<1000xf32>
// CHECK: %[[INIT:.+]] = linalg.init_tensor [] : tensor<f32>
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %cst) : tensor<f32>, f32 -> tensor<f32>
// CHECK: %[[GENERIC:.+]] = linalg.generic
@ -291,7 +291,7 @@ func @fold_unit_dim_for_init_tensor(%input: tensor<1x1000xf32>) -> tensor<1xf32>
// CHECK-SAME: iterator_types = ["reduction"]
// CHECK-SAME: ins(%[[INPUT_RESHAPE]] : tensor<1000xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<f32>)
// CHECK: %[[GENERIC_RESHAPE:.+]] = linalg.tensor_reshape %[[GENERIC]] [] : tensor<f32> into tensor<1xf32>
// CHECK: %[[GENERIC_RESHAPE:.+]] = linalg.tensor_expand_shape %[[GENERIC]] [] : tensor<f32> into tensor<1xf32>
// CHECK: return %[[GENERIC_RESHAPE:.+]] : tensor<1xf32>
@ -314,11 +314,11 @@ func @fold_subtensor(
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?x?x?x?x1x1xf32>
// CHECK: %[[SUBTENSOR1:.+]] = subtensor %[[ARG0]]
// CHECK-SAME: to tensor<?x?x?xf32>
// CHECK: %[[RESULT1:.+]] = linalg.tensor_reshape %[[SUBTENSOR1]]
// CHECK: %[[RESULT1:.+]] = linalg.tensor_expand_shape %[[SUBTENSOR1]]
// CHECK-SAME: [0, 1], [2], [3, 4, 5, 6]
// CHECK: %[[SUBTENSOR2:.+]] = subtensor %[[ARG1]]
// CHECK-SAME: to tensor<?x?x?xf32>
// CHECK: %[[RESULT2:.+]] = linalg.tensor_reshape %[[SUBTENSOR2]]
// CHECK: %[[RESULT2:.+]] = linalg.tensor_expand_shape %[[SUBTENSOR2]]
// CHECK-SAME: [0, 1], [2], [3, 4, 5, 6]
// CHECK: return %[[RESULT1]], %[[RESULT2]]
@ -346,7 +346,7 @@ func @unit_dim_for_reduction(%arg0: tensor<1x?x1x?xf32>) -> tensor<1x?xf32> {
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0)>
// CHECK: func @unit_dim_for_reduction
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x?x1x?xf32>
// CHECK-DAG: %[[RESHAPE:.+]] = linalg.tensor_reshape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
// CHECK-DAG: %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%{{.+}}] : tensor<?xf32>
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %{{.+}})
// CHECK: %[[RESULT:.+]] = linalg.generic
@ -354,7 +354,7 @@ func @unit_dim_for_reduction(%arg0: tensor<1x?x1x?xf32>) -> tensor<1x?xf32> {
// CHECK-SAME: iterator_types = ["parallel", "reduction"]
// CHECK-SAME: ins(%[[RESHAPE]] : tensor<?x?xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<?xf32>)
// CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[RESULT]] {{\[}}[0, 1]]
// CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_expand_shape %[[RESULT]] {{\[}}[0, 1]]
// CHECK: return %[[RESULT_RESHAPE]]
// -----
@ -380,7 +380,7 @@ func @unit_dim_for_reduction_keep_one(%arg0: tensor<1x?x1x1xf32>) -> tensor<1x1x
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0)>
// CHECK: func @unit_dim_for_reduction_keep_one
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x?x1x1xf32>
// CHECK-DAG: %[[RESHAPE:.+]] = linalg.tensor_reshape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
// CHECK-DAG: %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
// CHECK: %[[INIT:.+]] = linalg.init_tensor [1] : tensor<1xf32>
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %{{.+}})
// CHECK: %[[RESULT:.+]] = linalg.generic
@ -388,7 +388,7 @@ func @unit_dim_for_reduction_keep_one(%arg0: tensor<1x?x1x1xf32>) -> tensor<1x1x
// CHECK-SAME: iterator_types = ["parallel", "reduction"]
// CHECK-SAME: ins(%[[RESHAPE]] : tensor<?x1xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<1xf32>)
// CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[RESULT]] {{\[}}[0, 1]]
// CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_expand_shape %[[RESULT]] {{\[}}[0, 1]]
// CHECK: return %[[RESULT_RESHAPE]]
// -----
@ -415,7 +415,7 @@ func @unit_dim_for_reduction_inner(%arg0: tensor<?x1x?x1xf32>) -> tensor<?x1xf32
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0)>
// CHECK: func @unit_dim_for_reduction_inner
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x1xf32>
// CHECK-DAG: %[[RESHAPE:.+]] = linalg.tensor_reshape %[[ARG0]] {{\[}}[0, 1], [2, 3]]
// CHECK-DAG: %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]]
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%{{.+}}] : tensor<?xf32>
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %{{.+}})
// CHECK: %[[RESULT:.+]] = linalg.generic
@ -423,7 +423,7 @@ func @unit_dim_for_reduction_inner(%arg0: tensor<?x1x?x1xf32>) -> tensor<?x1xf32
// CHECK-SAME: iterator_types = ["parallel", "reduction"]
// CHECK-SAME: ins(%[[RESHAPE]] : tensor<?x?xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<?xf32>)
// CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[RESULT]] {{\[}}[0, 1]]
// CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_expand_shape %[[RESULT]] {{\[}}[0, 1]]
// CHECK: return %[[RESULT_RESHAPE]]
// -----
@ -435,7 +435,7 @@ func @subtensor_unit_dims(%arg0: tensor<1x3xf32>) -> tensor<1x1xf32> {
// CHECK-LABEL: func @subtensor_unit_dims
// CHECK: %[[SUBTENSOR:.+]] = subtensor
// CHECK-SAME: tensor<1x3xf32> to tensor<f32>
// CHECK: %[[RESULT:.+]] = linalg.tensor_reshape %[[SUBTENSOR]] []
// CHECK: %[[RESULT:.+]] = linalg.tensor_expand_shape %[[SUBTENSOR]] []
// CHECK: return %[[RESULT]]
// -----
@ -445,7 +445,7 @@ func @subtensor_insert_unit_dims(%arg0: tensor<1x3xf32>, %arg1: tensor<1x1xf32>)
return %0 : tensor<1x3xf32>
}
// CHECK-LABEL: func @subtensor_insert_unit_dims
// CHECK: %[[RESHAPE:.+]] = linalg.tensor_reshape %{{.+}} []
// CHECK: %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %{{.+}} []
// CHECK: %[[RESULT:.+]] = subtensor_insert %[[RESHAPE]]
// CHECK-SAME: tensor<f32> into tensor<1x3xf32>
// CHECK: return %[[RESULT]]

View File

@ -5,14 +5,14 @@
// CHECK-LABEL: func @reshape
// CHECK-SAME: (%[[A:.*]]: tensor<?x16xf32>, %[[B:.*]]: tensor<16xf32>, %[[INIT:.*]]: tensor<?x112x16xf32>)
// CHECK: %[[RI:.*]] = linalg.tensor_reshape %[[INIT]] {{\[}}[0, 1], [2]] : tensor<?x112x16xf32> into tensor<?x16xf32>
// CHECK: %[[RI:.*]] = linalg.tensor_collapse_shape %[[INIT]] {{\[}}[0, 1], [2]] : tensor<?x112x16xf32> into tensor<?x16xf32>
// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP2]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<?x16xf32>)
// CHECK: %[[RR:.*]] = linalg.tensor_reshape %[[R]] {{\[}}[0, 1], [2]] : tensor<?x16xf32> into tensor<?x112x16xf32>
// CHECK: %[[RR:.*]] = linalg.tensor_expand_shape %[[R]] {{\[}}[0, 1], [2]] : tensor<?x16xf32> into tensor<?x112x16xf32>
// CHECK: return %[[RR]] : tensor<?x112x16xf32>
func @reshape(%A: tensor<?x16xf32>, %B: tensor<16xf32>, %init: tensor<?x112x16xf32>) -> tensor<?x112x16xf32> {
%0 = linalg.tensor_reshape %A [[0, 1], [2]]
%0 = linalg.tensor_expand_shape %A [[0, 1], [2]]
: tensor<?x16xf32> into tensor<?x112x16xf32>
%2 = linalg.generic {indexing_maps = [
affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>,
@ -35,17 +35,17 @@ func @reshape(%A: tensor<?x16xf32>, %B: tensor<16xf32>, %init: tensor<?x112x16xf
// CHECK-LABEL: func @reshape_multiple
// CHECK-SAME: (%[[A:.*]]: tensor<12544x16xf32>, %[[B:.*]]: tensor<12544x16xf32>, %[[C:.*]]: tensor<16xf32>)
// CHECK: %[[I:.*]] = linalg.init_tensor [112, 112, 16] : tensor<112x112x16xf32>
// CHECK: %[[RI:.*]] = linalg.tensor_reshape %[[I]] {{\[}}[0, 1], [2]] : tensor<112x112x16xf32> into tensor<12544x16xf32>
// CHECK: %[[RI:.*]] = linalg.tensor_collapse_shape %[[I]] {{\[}}[0, 1], [2]] : tensor<112x112x16xf32> into tensor<12544x16xf32>
// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]], #[[$MAP3]], #[[$MAP2]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins(%[[A]], %[[B]], %[[C]] : tensor<12544x16xf32>, tensor<12544x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<12544x16xf32>)
// CHECK: %[[RR:.*]] = linalg.tensor_reshape %[[R]] {{\[}}[0, 1], [2]] : tensor<12544x16xf32> into tensor<112x112x16xf32>
// CHECK: %[[RR:.*]] = linalg.tensor_expand_shape %[[R]] {{\[}}[0, 1], [2]] : tensor<12544x16xf32> into tensor<112x112x16xf32>
// CHECK: return %[[RR]] : tensor<112x112x16xf32>
func @reshape_multiple(%A: tensor<12544x16xf32>, %B: tensor<12544x16xf32>,
%C: tensor<16xf32>) -> tensor<112x112x16xf32> {
%0 = linalg.tensor_reshape %A [[0, 1], [2]]
%0 = linalg.tensor_expand_shape %A [[0, 1], [2]]
: tensor<12544x16xf32> into tensor<112x112x16xf32>
%1 = linalg.tensor_reshape %B [[0, 1], [2]]
%1 = linalg.tensor_expand_shape %B [[0, 1], [2]]
: tensor<12544x16xf32> into tensor<112x112x16xf32>
%2 = linalg.init_tensor [112, 112, 16] : tensor<112x112x16xf32>
%3 = linalg.generic {indexing_maps = [
@ -69,11 +69,11 @@ func @reshape_multiple(%A: tensor<12544x16xf32>, %B: tensor<12544x16xf32>,
// Negative test, since the second source is broadcasted from d1 we cannot merge
// d0 and d1 dimensions
// CHECK-LABEL: func @reshape_negative
// CHECK: linalg.tensor_reshape {{.*}} : tensor<12544x16xf32> into tensor<112x112x16xf32>
// CHECK: linalg.tensor_expand_shape {{.*}} : tensor<12544x16xf32> into tensor<112x112x16xf32>
// CHECK: linalg.generic
// CHECK: } -> tensor<112x112x16xf32>
func @reshape_negative(%A: tensor<12544x16xf32>, %B: tensor<112xf32>) -> tensor<112x112x16xf32> {
%20 = linalg.tensor_reshape %A [[0, 1], [2]]
%20 = linalg.tensor_expand_shape %A [[0, 1], [2]]
: tensor<12544x16xf32> into tensor<112x112x16xf32>
%21 = linalg.init_tensor [112, 112, 16] : tensor<112x112x16xf32>
%22 = linalg.generic {indexing_maps = [
@ -96,7 +96,7 @@ func @type_correctness(%arg0 : tensor<6x5xi32>, %arg1 : tensor<5xf32>,
%cst_6 = constant 1.000000e+00 : f32
%cst_7 = constant 7.000000e+00 : f32
%cst_8 = constant 1.1920929E-7 : f32
%25 = linalg.tensor_reshape %arg0 [[0, 1], [2]]
%25 = linalg.tensor_expand_shape %arg0 [[0, 1], [2]]
: tensor<6x5xi32> into tensor<2x3x5xi32>
%26 = linalg.init_tensor [2, 3, 5] : tensor<2x3x5xf32>
%28 = linalg.generic {
@ -122,5 +122,5 @@ func @type_correctness(%arg0 : tensor<6x5xi32>, %arg1 : tensor<5xf32>,
// CHECK: %[[OP:.+]] = linalg.generic
// CHECK-SAME: ins(%{{.+}}, %{{.+}}, %{{.+}} : tensor<6x5xi32>, tensor<5xf32>, tensor<5xf32>)
// CHECK-SAME: outs(%{{.+}} : tensor<6x5xf32>)
// CHECK: linalg.tensor_reshape %[[OP]]
// CHECK: linalg.tensor_expand_shape %[[OP]]
// CHECK-SAME: tensor<6x5xf32> into tensor<2x3x5xf32>

View File

@ -348,21 +348,35 @@ func @generic(%arg0: memref<?x?xi4>) {
func @reshape(%arg0: memref<f32>) {
// expected-error @+1 {{expected non-zero memref ranks}}
%0 = linalg.reshape %arg0 [[0]] : memref<f32> into memref<f32>
%0 = linalg.expand_shape %arg0 [[0]] : memref<f32> into memref<f32>
}
// -----
func @collapse_to_higher_rank(%arg0: memref<f32>) {
// expected-error @+1 {{expected the type 'memref<f32>' to have higher rank than the type = 'memref<1xf32>'}}
%0 = linalg.collapse_shape %arg0 [[0]] : memref<f32> into memref<1xf32>
}
// -----
func @expand_to_smaller_rank(%arg0: memref<1xf32>) {
// expected-error @+1 {{expected the type 'memref<f32>' to have higher rank than the type = 'memref<1xf32>'}}
%0 = linalg.expand_shape %arg0 [[0]] : memref<1xf32> into memref<f32>
}
// -----
func @reshape(%arg0: memref<?xf32>) {
// expected-error @+1 {{expected to collapse or expand dims}}
%0 = linalg.reshape %arg0 [[0]] : memref<?xf32> into memref<?xf32>
%0 = linalg.collapse_shape %arg0 [[0]] : memref<?xf32> into memref<?xf32>
}
// -----
func @reshape(%arg0: memref<?x?x?xf32>) {
// expected-error @+1 {{expected rank of the collapsed type(2) to be the number of reassociation maps(1)}}
%0 = linalg.reshape %arg0 [[0, 1]] :
%0 = linalg.collapse_shape %arg0 [[0, 1]] :
memref<?x?x?xf32> into memref<?x?xf32, offset: 0, strides: [?, 1]>
}
@ -370,7 +384,7 @@ func @reshape(%arg0: memref<?x?x?xf32>) {
func @reshape(%arg0: memref<?x?x?xf32>) {
// expected-error @+1 {{expected reassociation map #1 to be valid and contiguous}}
%0 = linalg.reshape %arg0 [[0, 1], [1, 2]] :
%0 = linalg.collapse_shape %arg0 [[0, 1], [1, 2]] :
memref<?x?x?xf32> into memref<?x?xf32, offset: 0, strides: [?, 1]>
}
@ -378,7 +392,7 @@ func @reshape(%arg0: memref<?x?x?xf32>) {
func @reshape(%arg0: memref<?x?x?xf32>) {
// expected-error @+1 {{expected collapsed type to be 'memref<?x?xf32>', but got 'memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>>'}}
%0 = linalg.reshape %arg0 [[0, 1], [2]] :
%0 = linalg.collapse_shape %arg0 [[0, 1], [2]] :
memref<?x?x?xf32> into memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>>
}
@ -455,7 +469,7 @@ func @illegal_expanding_reshape_dynamic_tensor
(%arg0: tensor<?x?x?xf32>) -> tensor<?x?x?x4x?xf32>
{
// expected-error @+1 {{invalid to have a single dimension (2) expanded into multiple dynamic dims (2,4)}}
%0 = linalg.tensor_reshape %arg0 [[0], [1], [2, 3, 4]]
%0 = linalg.tensor_expand_shape %arg0 [[0], [1], [2, 3, 4]]
: tensor<?x?x?xf32> into tensor<?x?x?x4x?xf32>
return %0 : tensor<?x?x?x4x?xf32>
}
@ -466,7 +480,7 @@ func @illegal_expanding_reshape_dynamic_memref
(%arg0: memref<?x?x?xf32>) -> memref<?x?x?x4x?xf32>
{
// expected-error @+1 {{invalid to have a single dimension (2) expanded into multiple dynamic dims (2,4)}}
%0 = linalg.reshape %arg0 [[0], [1], [2, 3, 4]]
%0 = linalg.expand_shape %arg0 [[0], [1], [2, 3, 4]]
: memref<?x?x?xf32> into memref<?x?x?x4x?xf32>
return %0 : memref<?x?x?x4x?xf32>
}
@ -477,7 +491,7 @@ func @illegal_expanding_reshape_static_tensor
(%arg0: tensor<2x3x20xf32>) -> tensor<2x3x2x4x5xf32>
{
// expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}}
%0 = linalg.tensor_reshape %arg0 [[0], [1], [2, 3, 4]]
%0 = linalg.tensor_expand_shape %arg0 [[0], [1], [2, 3, 4]]
: tensor<2x3x20xf32> into tensor<2x3x2x4x5xf32>
return %0 : tensor<2x3x2x4x5xf32>
}
@ -488,7 +502,7 @@ func @illegal_collapsing_reshape_static_tensor
(%arg0: tensor<2x3x2x4x5xf32>) -> tensor<2x3x20xf32>
{
// expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}}
%0 = linalg.tensor_reshape %arg0 [[0], [1], [2, 3, 4]]
%0 = linalg.tensor_collapse_shape %arg0 [[0], [1], [2, 3, 4]]
: tensor<2x3x2x4x5xf32> into tensor<2x3x20xf32>
return %0 : tensor<2x3x20xf32>
}
@ -499,7 +513,7 @@ func @illegal_expanding_reshape_static_memref
(%arg0: memref<2x3x20xf32>) -> memref<2x3x2x4x5xf32>
{
// expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}}
%0 = linalg.reshape %arg0 [[0], [1], [2, 3, 4]]
%0 = linalg.expand_shape %arg0 [[0], [1], [2, 3, 4]]
: memref<2x3x20xf32> into memref<2x3x2x4x5xf32>
return %0 : memref<2x3x2x4x5xf32>
}
@ -510,87 +524,87 @@ func @illegal_collapsing_reshape_static_memref
(%arg0: memref<2x3x2x4x5xf32>) -> memref<2x3x20xf32>
{
// expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}}
%0 = linalg.reshape %arg0 [[0], [1], [2, 3, 4]]
%0 = linalg.collapse_shape %arg0 [[0], [1], [2, 3, 4]]
: memref<2x3x2x4x5xf32> into memref<2x3x20xf32>
return %0 : memref<2x3x20xf32>
}
// -----
func @illegal_collapsing_reshape_mixed_tensor(%arg0 : tensor<?x?xf32>) -> tensor<?x4x5xf32>
func @illegal_expanding_reshape_mixed_tensor(%arg0 : tensor<?x?xf32>) -> tensor<?x4x5xf32>
{
// expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}}
%0 = linalg.tensor_reshape %arg0 [[0, 1], [2]]
%0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2]]
: tensor<?x?xf32> into tensor<?x4x5xf32>
return %0 : tensor<?x4x5xf32>
}
// -----
func @illegal_collapsing_reshape_mixed_tensor_2(%arg0 : tensor<?x?xf32>) -> tensor<?x4x5xf32>
func @illegal_expanding_reshape_mixed_tensor_2(%arg0 : tensor<?x?xf32>) -> tensor<?x4x5xf32>
{
// expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}}
%0 = linalg.tensor_reshape %arg0 [[0], [1, 2]]
%0 = linalg.tensor_expand_shape %arg0 [[0], [1, 2]]
: tensor<?x?xf32> into tensor<?x4x5xf32>
return %0 : tensor<?x4x5xf32>
}
// -----
func @illegal_expanding_reshape_mixed_tensor(%arg0 : tensor<?x4x5xf32>) -> tensor<?x?xf32>
func @illegal_collapsing_reshape_mixed_tensor(%arg0 : tensor<?x4x5xf32>) -> tensor<?x?xf32>
{
// expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}}
%0 = linalg.tensor_reshape %arg0 [[0, 1], [2]]
%0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2]]
: tensor<?x4x5xf32> into tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// -----
func @illegal_expanding_reshape_mixed_tensor_2(%arg0 : tensor<?x4x5xf32>) -> tensor<?x?xf32>
func @illegal_collapsing_reshape_mixed_tensor_2(%arg0 : tensor<?x4x5xf32>) -> tensor<?x?xf32>
{
// expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}}
%0 = linalg.tensor_reshape %arg0 [[0], [1, 2]]
%0 = linalg.tensor_collapse_shape %arg0 [[0], [1, 2]]
: tensor<?x4x5xf32> into tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// -----
func @illegal_collapsing_reshape_mixed_memref(%arg0 : memref<?x?xf32>) -> memref<?x4x5xf32>
func @illegal_expanding_reshape_mixed_memref(%arg0 : memref<?x?xf32>) -> memref<?x4x5xf32>
{
// expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}}
%0 = linalg.reshape %arg0 [[0, 1], [2]]
%0 = linalg.expand_shape %arg0 [[0, 1], [2]]
: memref<?x?xf32> into memref<?x4x5xf32>
return %0 : memref<?x4x5xf32>
}
// -----
func @illegal_collapsing_reshape_mixed_memref_2(%arg0 : memref<?x?xf32>) -> memref<?x4x5xf32>
func @illegal_expanding_reshape_mixed_memref_2(%arg0 : memref<?x?xf32>) -> memref<?x4x5xf32>
{
// expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}}
%0 = linalg.reshape %arg0 [[0], [1, 2]]
%0 = linalg.expand_shape %arg0 [[0], [1, 2]]
: memref<?x?xf32> into memref<?x4x5xf32>
return %0 : memref<?x4x5xf32>
}
// -----
func @illegal_expanding_reshape_mixed_memref(%arg0 : memref<?x4x5xf32>) -> memref<?x?xf32>
func @illegal_collapsing_reshape_mixed_memref(%arg0 : memref<?x4x5xf32>) -> memref<?x?xf32>
{
// expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}}
%0 = linalg.reshape %arg0 [[0, 1], [2]]
%0 = linalg.collapse_shape %arg0 [[0, 1], [2]]
: memref<?x4x5xf32> into memref<?x?xf32>
return %0 : memref<?x?xf32>
}
// -----
func @illegal_expanding_reshape_mixed_memref_2(%arg0 : memref<?x4x5xf32>) -> memref<?x?xf32>
func @illegal_collapse_reshape_mixed_memref_2(%arg0 : memref<?x4x5xf32>) -> memref<?x?xf32>
{
// expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}}
%0 = linalg.reshape %arg0 [[0], [1, 2]]
%0 = linalg.collapse_shape %arg0 [[0], [1, 2]]
: memref<?x4x5xf32> into memref<?x?xf32>
return %0 : memref<?x?xf32>
}

View File

@ -14,13 +14,13 @@ func @range(%arg0: index) {
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(i64, i64, i64)>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(i64, i64, i64)>
func @reshape_static_expand(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
func @expand_shape_static(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
// Reshapes that expand a contiguous tensor with some 1's.
%0 = linalg.reshape %arg0 [[0, 1], [2], [3, 4]]
%0 = linalg.expand_shape %arg0 [[0, 1], [2], [3, 4]]
: memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
return %0 : memref<1x3x4x1x5xf32>
}
// CHECK-LABEL: func @reshape_static_expand
// CHECK-LABEL: func @expand_shape_static
// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
@ -49,12 +49,12 @@ func @reshape_static_expand(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
// CHECK: llvm.mlir.constant(1 : index) : i64
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 4] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
func @reshape_static_collapse(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> {
%0 = linalg.reshape %arg0 [[0, 1], [2], [3, 4]] :
func @collapse_shape_static(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> {
%0 = linalg.collapse_shape %arg0 [[0, 1], [2], [3, 4]] :
memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
return %0 : memref<3x4x5xf32>
}
// CHECK-LABEL: func @reshape_static_collapse
// CHECK-LABEL: func @collapse_shape_static
// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
@ -75,11 +75,11 @@ func @reshape_static_collapse(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32>
// CHECK: llvm.mlir.constant(1 : index) : i64
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
func @collapse_shape_fold_zero_dim(%arg0 : memref<1x1xf32>) -> memref<f32> {
%0 = linalg.collapse_shape %arg0 [] : memref<1x1xf32> into memref<f32>
return %0 : memref<f32>
}
// CHECK-LABEL: func @collapse_shape_fold_zero_dim
// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
@ -88,11 +88,11 @@ func @reshape_fold_zero_dim(%arg0 : memref<1x1xf32>) -> memref<f32> {
// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
func @expand_shape_zero_dim(%arg0 : memref<f32>) -> memref<1x1xf32> {
%0 = linalg.expand_shape %arg0 [] : memref<f32> into memref<1x1xf32>
return %0 : memref<1x1xf32>
}
// CHECK-LABEL: func @expand_shape_zero_dim
// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>

View File

@ -6,7 +6,7 @@ func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xf32>,
%arg1 : tensor<?x?x?xf32>) ->
tensor<?x?x?xf32>
{
%0 = linalg.tensor_collapse_shape %arg0 [[0], [1, 2], [3]] :
tensor<?x?x4x?xf32> into tensor<?x?x?xf32>
%1 = linalg.generic {
indexing_maps = [#map0, #map1, #map1],
@ -25,16 +25,16 @@ func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xf32>,
// CHECK: func @generic_op_reshape_producer_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK: %[[T0:.+]] = linalg.tensor_collapse_shape %[[ARG0]]
// CHECK-SAME: [0], [1, 2], [3]
// CHECK: %[[T1:.+]] = linalg.tensor_expand_shape %[[ARG1]]
// CHECK-SAME: [0], [1], [2, 3]
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP6]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[ARG0]], %[[T1]] : tensor<?x?x4x?xf32>, tensor<?x?x?x4xf32>)
// CHECK-SAME: outs(%{{.+}} : tensor<?x?x?x4xf32>)
// CHECK: %[[T4:.+]] = linalg.tensor_collapse_shape %[[T3]]
// CHECK-SAME: [0], [1], [2, 3]
// CHECK-SAME: tensor<?x?x?x4xf32> into tensor<?x?x?xf32>
// CHECK: return %[[T4]]
@ -55,19 +55,19 @@ func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
%1 = mulf %arg3, %arg4 : f32
linalg.yield %1 : f32
} -> tensor<?x?xf32>
%1 = linalg.tensor_expand_shape %0 [[0], [1, 2, 3]] :
tensor<?x?xf32> into tensor<?x4x?x5xf32>
return %1 : tensor<?x4x?x5xf32>
}
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK: func @generic_op_reshape_consumer_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>)
// CHECK: %[[T0:.+]] = linalg.tensor_expand_shape %[[ARG0]]
// CHECK-SAME: [0], [1, 2, 3]
// CHECK-SAME: tensor<?x?xf32> into tensor<?x4x?x5xf32>
// CHECK: %[[T1:.+]] = linalg.tensor_expand_shape %[[ARG1]]
// CHECK-SAME: [0], [1, 2, 3]
// CHECK-SAME: tensor<?x?xf32> into tensor<?x4x?x5xf32>
// CHECK: %[[T3:.+]] = linalg.generic
@ -94,7 +94,7 @@ func @reshape_as_consumer_permutation
%1 = addf %arg0, %arg1 : f32
linalg.yield %1 : f32
} -> tensor<?x?x?xf32>
%d = linalg.tensor_expand_shape %c [[0, 1], [2], [3, 4, 5]]
: tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
return %d : tensor<?x2x?x3x4x?xf32>
}
@ -104,10 +104,10 @@ func @reshape_as_consumer_permutation
// CHECK: func @reshape_as_consumer_permutation
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK: %[[T0:.+]] = linalg.tensor_expand_shape %[[ARG0]]
// CHECK-SAME: [0, 1, 2], [3, 4], [5]
// CHECK-SAME: tensor<?x?x?xf32> into tensor<3x4x?x?x2x?xf32>
// CHECK: %[[T1:.+]] = linalg.tensor_expand_shape %[[ARG1]]
// CHECK-SAME: [0, 1, 2], [3]
// CHECK-SAME: tensor<?x?xf32> into tensor<3x4x?x?xf32>
// CHECK: %[[T3:.+]] = linalg.generic
@ -136,7 +136,7 @@ func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
%2 = mulf %arg1, %arg2 : f32
linalg.yield %2 : f32
} -> tensor<264x4xf32>
%2 = linalg.tensor_expand_shape %1 [[0, 1], [2]] :
tensor<264x4xf32> into tensor<8x33x4xf32>
return %2 : tensor<8x33x4xf32>
}
@ -144,7 +144,7 @@ func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK: func @generic_op_reshape_consumer_static
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<264x4xf32>
// CHECK: %[[T0:.+]] = linalg.tensor_expand_shape %[[ARG0]]
// CHECK-SAME: [0, 1], [2]
// CHECK-SAME: tensor<264x4xf32> into tensor<8x33x4xf32>
// CHECK: %[[T1:.+]] = linalg.init_tensor [8, 33, 4]
@ -163,7 +163,7 @@ func @indexed_consumer_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
%arg1 : tensor<?x?x?xi32>) ->
tensor<?x?x?xi32>
{
%0 = linalg.tensor_collapse_shape %arg0 [[0], [1, 2], [3]]:
tensor<?x?x4x?xi32> into tensor<?x?x?xi32>
%1 = linalg.generic {
indexing_maps = [#map0, #map1, #map1],
@ -229,7 +229,7 @@ func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
%5 = addi %3, %4 : i32
linalg.yield %5 : i32
} -> tensor<?x?xi32>
%1 = linalg.tensor_expand_shape %0 [[0], [1, 2, 3]] :
tensor<?x?xi32> into tensor<?x?x4x5xi32>
return %1 : tensor<?x?x4x5xi32>
}
@ -279,7 +279,7 @@ func @reshape_as_consumer_permutation
%7 = addi %5, %6 : i32
linalg.yield %7 : i32
} -> tensor<6x4x210xi32>
%d = linalg.tensor_expand_shape %c [[0, 1], [2], [3, 4, 5]]
: tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32>
return %d : tensor<2x3x4x5x6x7xi32>
}
@ -293,9 +293,9 @@ func @reshape_as_consumer_permutation
// CHECK: func @reshape_as_consumer_permutation
// CHECK-SAME: %[[ARG0:.+]]: tensor<210x6x4xi32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<210x4xi32>
// CHECK-DAG: %[[T1:.+]] = linalg.tensor_expand_shape %[[ARG0]]
// CHECK-SAME: [0, 1, 2], [3, 4], [5]
// CHECK-DAG: %[[T2:.+]] = linalg.tensor_expand_shape %[[ARG1]]
// CHECK-SAME: [0, 1, 2], [3]
// CHECK-DAG: %[[T0:.+]] = linalg.init_tensor [2, 3, 4, 5, 6, 7]
// CHECK: %[[T4:.+]] = linalg.generic
@ -326,7 +326,7 @@ func @reshape_as_consumer_permutation
func @reshape_as_producer_projected_permutation(
%arg0 : tensor<33x8x?xi32>, %shape : tensor<264x?x4xi32>) -> tensor<264x?x4xi32>
{
%0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2]]
: tensor<33x8x?xi32> into tensor<264x?xi32>
%1 = linalg.generic
{indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>,
@ -372,7 +372,7 @@ func @reshape_as_producer_projected_permutation(
// CHECK: %[[T5:.+]] = index_cast %[[IDX3]] : index to i32
// CHECK: %[[T6:.+]] = addi %[[T4]], %[[T5]] : i32
// CHECK: linalg.yield %[[T6]] : i32
// CHECK: %[[RES2:.+]] = linalg.tensor_collapse_shape %[[RES]]
// CHECK-SAME: [0, 1], [2], [3]
// CHECK-SAME: : tensor<33x8x?x4xi32> into tensor<264x?x4xi32>
// CHECK: return %[[RES2]] : tensor<264x?x4xi32>
@ -394,7 +394,7 @@ func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
%1 = mulf %arg3, %arg4 : f32
linalg.yield %1 : f32
} -> tensor<?x?xf32>
%1 = linalg.tensor_expand_shape %0 [[0], [1, 2, 3]] :
tensor<?x?xf32> into tensor<?x?x4x5xf32>
return %1 : tensor<?x?x4x5xf32>
}
@ -404,10 +404,10 @@ func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
// CHECK: func @generic_op_reshape_consumer_fusion_projected
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK: %[[T0:.+]] = linalg.tensor_expand_shape %[[ARG0]]
// CHECK-SAME: [0, 1, 2], [3]
// CHECK-SAME: tensor<?x?xf32> into tensor<?x4x5x?xf32>
// CHECK: %[[T1:.+]] = linalg.tensor_expand_shape %[[ARG1]]
// CHECK-SAME: [0, 1, 2], [3]
// CHECK-SAME: tensor<?x?xf32> into tensor<?x4x5x?xf32>
// CHECK: %[[T3:.+]] = linalg.generic
@ -420,7 +420,7 @@ func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
// -----
func @unit_dim_reshape_expansion(%arg0 : tensor<1x5xf32>) -> tensor<5x5xf32> {
%0 = linalg.tensor_collapse_shape %arg0 [[0, 1]]
: tensor<1x5xf32> into tensor<5xf32>
%1 = linalg.init_tensor [5, 5] : tensor<5x5xf32>
%2 = linalg.generic
@ -434,7 +434,7 @@ func @unit_dim_reshape_expansion(%arg0 : tensor<1x5xf32>) -> tensor<5x5xf32> {
return %2 : tensor<5x5xf32>
}
// CHECK: func @unit_dim_reshape_expansion
// CHECK-DAG: linalg.tensor_collapse_shape
// CHECK-DAG: linalg.init_tensor
// CHECK: linalg.generic
@ -450,14 +450,14 @@ func @unit_dim_reshape_collapse(%arg0 : tensor<5xf32>) -> tensor<5x1x5xf32> {
^bb0(%arg2: f32, %arg3: f32): // no predecessors
linalg.yield %arg2 : f32
} -> tensor<5x5xf32>
%2 = linalg.tensor_expand_shape %1 [[0, 1], [2]]
: tensor<5x5xf32> into tensor<5x1x5xf32>
return %2 : tensor<5x1x5xf32>
}
// CHECK: func @unit_dim_reshape_collapse
// CHECK: linalg.init_tensor
// CHECK: linalg.generic
// CHECK: linalg.tensor_expand_shape
// -----
@ -465,7 +465,7 @@ func @unit_dim_reshape_expansion_full
(%arg0 : tensor<1x?x1x2x1x4xf32>, %arg1 : tensor<?x2x4xf32>)
-> tensor<?x2x4xf32> {
%c1 = constant 1 : index
%0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2], [3, 4], [5]]
: tensor<1x?x1x2x1x4xf32> into tensor<?x2x4xf32>
%1 = memref.dim %arg0, %c1 : tensor<1x?x1x2x1x4xf32>
%2 = linalg.init_tensor [%1, 2, 4] : tensor<?x2x4xf32>
@ -483,7 +483,7 @@ func @unit_dim_reshape_expansion_full
return %3 : tensor<?x2x4xf32>
}
// CHECK: func @unit_dim_reshape_expansion_full
// CHECK-DAG: linalg.tensor_collapse_shape
// CHECK-DAG: linalg.init_tensor
// CHECK: linalg.generic
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<?x2x4xf32>, tensor<?x2x4xf32>)
@ -491,7 +491,7 @@ func @unit_dim_reshape_expansion_full
// FOLDUNITDIM: func @unit_dim_reshape_expansion_full
// FOLDUNITDIM-SAME: %[[ARG0:.+]]: tensor<1x?x1x2x1x4xf32>
// FOLDUNITDIM-SAME: %[[ARG1:.+]]: tensor<?x2x4xf32>
// FOLDUNITDIM-DAG: %[[RESHAPE:.+]] = linalg.tensor_expand_shape %[[ARG1]]
// FOLDUNITDIM: linalg.generic
// FOLDUNITDIM-SAME: ins(%[[ARG0]], %[[RESHAPE]] : tensor<1x?x1x2x1x4xf32>, tensor<1x?x1x2x1x4xf32>)
// FOLDUNITDIM-SAME: outs(%{{.+}} : tensor<1x?x1x2x1x4xf32>)

View File

@ -3,7 +3,7 @@
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xi32>)
-> tensor<?x?x4x?xi32> {
%0 = linalg.tensor_expand_shape %arg0 [[0], [1, 2], [3]] :
tensor<?x?x?xi32> into tensor<?x?x4x?xi32>
%1 = linalg.generic {
indexing_maps = [#map0, #map0],
@ -22,7 +22,7 @@ func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xi32>)
// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK: func @generic_op_reshape_producer_fusion
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xi32>
// CHECK: %[[T0:.+]] = linalg.tensor_expand_shape %[[ARG0]]
// CHECK-SAME: [0], [1, 2], [3]
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP3]], #[[MAP4]]]
@ -46,7 +46,7 @@ func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xi32>)
%3 = addi %arg6, %2 : i32
linalg.yield %3 : i32
} -> tensor<?x?x4x5xi32>
%1 = linalg.tensor_collapse_shape %0 [[0], [1, 2, 3]] :
tensor<?x?x4x5xi32> into tensor<?x?xi32>
return %1 : tensor<?x?xi32>
}
@ -54,21 +54,21 @@ func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xi32>)
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
// CHECK: func @generic_op_reshape_consumer_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x5xi32>
// CHECK: %[[T0:.+]] = linalg.tensor_collapse_shape %[[ARG0]]
// CHECK-SAME: [0], [1, 2, 3]
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]]
// CHECK-SAME: outs(%[[T0]] : tensor<?x?xi32>)
// CHECK: %[[IDX:.+]] = linalg.index 0 : index
// CHECK-NEXT: %[[IDX_CASTED:.+]] = index_cast %[[IDX]] : index to i32
// CHECK-NOT: linalg.tensor_collapse_shape
// -----
#map2 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func @generic_op_021_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<3x7x5xf32> {
%0 = linalg.tensor_expand_shape %arg0 [[0], [1, 2]]
: tensor<3x35xf32> into tensor<3x5x7xf32>
%1 = linalg.init_tensor [3, 7, 5] : tensor<3x7x5xf32>
%2 = linalg.generic
@ -84,7 +84,7 @@ func @generic_op_021_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf3
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 + d2 * 7)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK: func @generic_op_021_permultation_reshape_producer_fusion
// CHECK-NOT: linalg.tensor_expand_shape
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
@ -93,7 +93,7 @@ func @generic_op_021_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf3
#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
func @generic_op_120_permutation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x7x3xf32> {
%0 = linalg.tensor_expand_shape %arg0 [[0], [1, 2]]
: tensor<3x35xf32> into tensor<3x5x7xf32>
%1 = linalg.init_tensor [5, 7, 3] : tensor<5x7x3xf32>
%2 = linalg.generic
@ -109,7 +109,7 @@ func @generic_op_120_permutation_reshape_producer_fusion(%arg0 : tensor<3x35xf32
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
// CHECK: func @generic_op_120_permutation_reshape_producer_fusion
// CHECK-NOT: linalg.tensor_expand_shape
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
@ -120,7 +120,7 @@ func @generic_op_120_permutation_reshape_producer_fusion(%arg0 : tensor<3x35xf32
#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func @generic_op_102_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x3x7xf32> {
%0 = linalg.tensor_expand_shape %arg0 [[0], [1, 2]]
: tensor<3x35xf32> into tensor<3x5x7xf32>
%1 = linalg.init_tensor [5, 3, 7] : tensor<5x3x7xf32>
%2 = linalg.generic
@ -137,7 +137,7 @@ func @generic_op_102_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf3
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK: func @generic_op_102_permultation_reshape_producer_fusion
// CHECK-NOT: linalg.tensor_expand_shape
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
@ -156,7 +156,7 @@ func @generic_op_102_permultation_reshape_consumer_fusion(%arg0 : tensor<3x5x7xf
^bb0(%arg2: f32, %arg3 : f32): // no predecessors
linalg.yield %arg2 : f32
} -> tensor<5x3x7xf32>
%2 = linalg.tensor_collapse_shape %1 [[0], [1, 2]]
: tensor<5x3x7xf32> into tensor<5x21xf32>
return %2 : tensor<5x21xf32>
}
@ -165,7 +165,7 @@ func @generic_op_102_permultation_reshape_consumer_fusion(%arg0 : tensor<3x5x7xf
// CHECK: func @generic_op_102_permultation_reshape_consumer_fusion
// CHECK-SAME: %[[ARG0:.+]]: tensor<3x5x7xf32>
// CHECK: %[[T0:.+]] = linalg.init_tensor [5, 3, 7]
// CHECK: %[[T1:.+]] = linalg.tensor_collapse_shape %[[T0]]
// CHECK-SAME: [0], [1, 2]
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]]
@ -188,7 +188,7 @@ func @generic_op_reshape_consumer_nofusion(%arg0 : tensor<?x?x?x5xf32>,
%1 = mulf %arg3, %arg4 : f32
linalg.yield %1 : f32
} -> tensor<?x?x?x5xf32>
%1 = linalg.tensor_collapse_shape %0 [[0], [1, 2, 3]] :
tensor<?x?x?x5xf32> into tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
@ -197,5 +197,5 @@ func @generic_op_reshape_consumer_nofusion(%arg0 : tensor<?x?x?x5xf32>,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?x5xf32>
// CHECK: %[[NOFUSE:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
// CHECK: %[[RESULT:.+]] = linalg.tensor_collapse_shape %[[NOFUSE]]
// CHECK: return %[[RESULT]]

View File

@ -563,92 +563,92 @@ func @indexed_generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?,
func @reshape_static(%arg0: memref<3x4x5xf32>, %arg1: tensor<3x4x5xf32>,
%arg2: tensor<3x?x5xf32>) {
// Reshapes that collapse and expand back a contiguous buffer.
%0 = linalg.collapse_shape %arg0 [[0, 1], [2]] :
memref<3x4x5xf32> into memref<12x5xf32>
%r0 = linalg.expand_shape %0 [[0, 1], [2]] :
memref<12x5xf32> into memref<3x4x5xf32>
%1 = linalg.collapse_shape %arg0 [[0], [1, 2]] :
memref<3x4x5xf32> into memref<3x20xf32>
%r1 = linalg.expand_shape %1 [[0], [1, 2]] :
memref<3x20xf32> into memref<3x4x5xf32>
%2 = linalg.collapse_shape %arg0 [[0, 1, 2]] :
memref<3x4x5xf32> into memref<60xf32>
%r2 = linalg.expand_shape %2 [[0, 1, 2]] :
memref<60xf32> into memref<3x4x5xf32>
// Reshapes that expand and collapse back a contiguous buffer with some 1's.
%3 = linalg.expand_shape %arg0 [[0, 1], [2], [3, 4]] :
memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
%r3 = linalg.collapse_shape %3 [[0, 1], [2], [3, 4]] :
memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
// Reshapes on tensors.
%t0 = linalg.tensor_expand_shape %arg1 [[0, 1], [2], [3, 4]] :
tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
%rt0 = linalg.tensor_collapse_shape %t0 [[0, 1], [2], [3, 4]] :
tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
%t1 = linalg.tensor_expand_shape %arg2 [[0, 1], [2], [3, 4]] :
tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
%rt1 = linalg.tensor_collapse_shape %t1 [[0], [1, 2], [3, 4]] :
tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
return
}
// CHECK-LABEL: func @reshape_static
// CHECK: linalg.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
// CHECK-SAME: memref<3x4x5xf32> into memref<12x5xf32>
// CHECK: linalg.expand_shape {{.*}} {{\[}}[0, 1], [2]]
// CHECK-SAME: memref<12x5xf32> into memref<3x4x5xf32>
// CHECK: linalg.collapse_shape {{.*}} {{\[}}[0], [1, 2]]
// CHECK-SAME: memref<3x4x5xf32> into memref<3x20xf32>
// CHECK: linalg.expand_shape {{.*}} {{\[}}[0], [1, 2]]
// CHECK-SAME: memref<3x20xf32> into memref<3x4x5xf32>
// CHECK: linalg.collapse_shape {{.*}} {{\[}}[0, 1, 2]]
// CHECK-SAME: memref<3x4x5xf32> into memref<60xf32>
// CHECK: linalg.expand_shape {{.*}} {{\[}}[0, 1, 2]]
// CHECK-SAME: memref<60xf32> into memref<3x4x5xf32>
// CHECK: linalg.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
// CHECK-SAME: memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
// CHECK: linalg.collapse_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
// CHECK-SAME: memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
//
// CHECK: linalg.tensor_expand_shape {{.*}}: tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
// CHECK: linalg.tensor_collapse_shape {{.*}}: tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
// CHECK: linalg.tensor_expand_shape {{.*}}: tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
// CHECK: linalg.tensor_collapse_shape {{.*}}: tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
// -----
func @reshape_dynamic(%arg0: memref<?x?x?xf32>,
%arg1: memref<?x?x?xf32, offset : 0, strides : [?, ?, 1]>,
%arg2: memref<?x?x?xf32, offset : ?, strides : [?, ?, 1]>) {
%0 = linalg.collapse_shape %arg0 [[0, 1], [2]] :
memref<?x?x?xf32> into memref<?x?xf32>
%r0 = linalg.expand_shape %0 [[0, 1], [2]] :
memref<?x?xf32> into memref<?x4x?xf32>
%1 = linalg.collapse_shape %arg1 [[0, 1], [2]] :
memref<?x?x?xf32, offset : 0, strides : [?, ?, 1]> into
memref<?x?xf32, offset : 0, strides : [?, 1]>
%r1 = linalg.expand_shape %1 [[0, 1], [2]] :
memref<?x?xf32, offset : 0, strides : [?, 1]> into
memref<?x4x?xf32, offset : 0, strides : [?, ?, 1]>
%2 = linalg.collapse_shape %arg2 [[0, 1], [2]] :
memref<?x?x?xf32, offset : ?, strides : [?, ?, 1]> into
memref<?x?xf32, offset : ?, strides : [?, 1]>
%r2 = linalg.expand_shape %2 [[0, 1], [2]] :
memref<?x?xf32, offset : ?, strides : [?, 1]> into
memref<?x4x?xf32, offset : ?, strides : [?, ?, 1]>
return
}
// CHECK-LABEL: func @reshape
// CHECK: linalg.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
// CHECK-SAME: memref<?x?x?xf32> into memref<?x?xf32>
// CHECK: linalg.expand_shape {{.*}} {{\[}}[0, 1], [2]]
// CHECK-SAME: memref<?x?xf32> into memref<?x4x?xf32>
// CHECK: linalg.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
// CHECK-SAME: memref<?x?x?xf32, #[[$strided3DOFF0]]> into memref<?x?xf32, #[[$strided2DOFF0]]>
// CHECK: linalg.expand_shape {{.*}} {{\[}}[0, 1], [2]]
// CHECK-SAME: memref<?x?xf32, #[[$strided2DOFF0]]> into memref<?x4x?xf32, #[[$strided3DOFF0]]>
// CHECK: linalg.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
// CHECK-SAME: memref<?x?x?xf32, #[[$strided3D]]> into memref<?x?xf32, #[[$strided2D]]>
// CHECK: linalg.expand_shape {{.*}} {{\[}}[0, 1], [2]]
// CHECK-SAME: memref<?x?xf32, #[[$strided2D]]> into memref<?x4x?xf32, #[[$strided3D]]>
func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?x?xf32>, %c3: memref<?x?x?xf32>,
@ -679,25 +679,25 @@ func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?x?xf32>, %c3: memref<?x?x
func @tensor_reshape_zero_dim(%arg0 : tensor<1x1xf32>, %arg1 : tensor<f32>) -> (tensor<f32>, tensor<1x1xf32>)
{
%0 = linalg.tensor_collapse_shape %arg0 [] : tensor<1x1xf32> into tensor<f32>
%1 = linalg.tensor_expand_shape %0 [] : tensor<f32> into tensor<1x1xf32>
return %0, %1 : tensor<f32>, tensor<1x1xf32>
}
// CHECK-LABEL: func @tensor_reshape_zero_dim
// CHECK: linalg.tensor_collapse_shape %{{.*}} [] : tensor<1x1xf32> into tensor<f32>
// CHECK: linalg.tensor_expand_shape %{{.*}} [] : tensor<f32> into tensor<1x1xf32>
// -----
func @memref_reshape_zero_dim(%arg0 : memref<1x1xf32>, %arg1 : memref<f32>) -> (memref<f32>, memref<1x1xf32>)
{
%0 = linalg.collapse_shape %arg0 [] : memref<1x1xf32> into memref<f32>
%1 = linalg.expand_shape %0 [] : memref<f32> into memref<1x1xf32>
return %0, %1 : memref<f32>, memref<1x1xf32>
}
// CHECK-LABEL: func @memref_reshape_zero_dim
// CHECK: linalg.collapse_shape %{{.*}} [] : memref<1x1xf32> into memref<f32>
// CHECK: linalg.expand_shape %{{.*}} [] : memref<f32> into memref<1x1xf32>
// -----
@ -716,12 +716,12 @@ func @init_tensor(%arg0 : index, %arg1 : index)
func @legal_collapsing_reshape_dynamic_tensor
(%arg0: tensor<?x?x?x4x?xf32>) -> tensor<?x?x?xf32>
{
%0 = linalg.tensor_collapse_shape %arg0 [[0], [1], [2, 3, 4]] :
tensor<?x?x?x4x?xf32> into tensor<?x?x?xf32>
return %0 : tensor<?x?x?xf32>
}
// CHECK: func @legal_collapsing_reshape_dynamic_tensor
// CHECK: linalg.tensor_collapse_shape
// CHECK-SAME: [0], [1], [2, 3, 4]
// -----
@ -729,12 +729,12 @@ func @legal_collapsing_reshape_dynamic_tensor
func @legal_collapsing_reshape_dynamic_memref
(%arg0: memref<?x?x?x4x?xf32>) -> memref<?x?x?xf32>
{
%0 = linalg.collapse_shape %arg0 [[0], [1], [2, 3, 4]] :
memref<?x?x?x4x?xf32> into memref<?x?x?xf32>
return %0 : memref<?x?x?xf32>
}
// CHECK: func @legal_collapsing_reshape_dynamic_memref
// CHECK: linalg.collapse_shape
// CHECK-SAME: [0], [1], [2, 3, 4]
// -----