[mlir][Linalg] Add canonicalization of linalg op -> dim op.

Add a canonicalization pattern that replaces uses of the result of a linalg
operation on tensors in a dim operation with uses of one of the operands of
the linalg operation instead. This allows the linalg op itself to be deleted
when all its non-dim uses are removed (say through tiling, etc.).

Differential Revision: https://reviews.llvm.org/D93076
Commit 774c9c6ef3 (parent 547b032ccc)
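Concretely, the rewrite turns the first snippet below into the second (an
illustrative sketch using linalg.matmul; the %arg names are hypothetical, but
the pattern's own documentation further down uses the same shape):

  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
                     outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %1 = dim %0, %c0 : tensor<?x?xf32>

becomes

  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
                     outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %1 = dim %arg0, %c0 : tensor<?x?xf32>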
@@ -32,6 +32,9 @@ def Linalg_Dialect : Dialect {
     the op semantics.
   }];
   let cppNamespace = "::mlir::linalg";
+  let dependentDialects = [
+    "AffineDialect", "StandardOpsDialect", "tensor::TensorDialect"
+  ];
 }

 // Whether a type is a RangeType.
@@ -946,6 +946,56 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         return inversePermutation(getLoopsToShapesMap());
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the position in the results of the affine map computed
+        by getLoopsToShapesMap() that represents the shape of an
+        operand (input or output) at a dimension.
+      }],
+      /*retTy=*/"Optional<unsigned>",
+      /*methodName=*/"getOperandDimPositionInLoopsToShapeMap",
+      /*args=*/(ins "unsigned":$operandIdx, "unsigned":$dim),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        unsigned pos = 0;
+        for (auto type : llvm::enumerate(getShapedOperandTypes())) {
+          if (type.index() == operandIdx) return pos + dim;
+          pos += type.value().getRank();
+        }
+        return {};
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the position in the results of the affine map computed
+        by getLoopsToShapesMap() that represents the shape of an
+        input operand at a dimension.
+      }],
+      /*retTy=*/"Optional<unsigned>",
+      /*methodName=*/"getInputValueDimPositionInLoopsToShapeMap",
+      /*args=*/(ins "unsigned":$inputIdx, "unsigned":$dim),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        if (inputIdx >= getNumInputs()) return {};
+        return getOperandDimPositionInLoopsToShapeMap(inputIdx, dim);
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the position in the results of the affine map computed
+        by getLoopsToShapesMap() that represents the shape of the
+        result value at a dimension.
+      }],
+      /*retTy=*/"Optional<unsigned>",
+      /*methodName=*/"getResultValueDimPositionInLoopsToShapeMap",
+      /*args=*/(ins "unsigned":$resultIdx, "unsigned":$dim),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        if (resultIdx >= getNumOutputs()) return {};
+        return getOperandDimPositionInLoopsToShapeMap(
+          getNumInputs() + resultIdx, dim);
+      }]
+    >,

     //===------------------------------------------------------------------===//
     // Other static interface methods.
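To see what these positions mean: for linalg.matmul with operands A<?x?>,
B<?x?> and output C<?x?>, the operand shapes flatten to the list
[A0, A1, B0, B1, C0, C1], so (operandIdx=1, dim=0) maps to position 2 and
(result 0, dim 1) maps to position 5. A self-contained C++ sketch of the same
arithmetic (a simplified stand-in for the TableGen default implementation
above, not the generated code itself):

  #include <cstdio>
  #include <vector>

  // Flatten (operandIdx, dim) into a position over the concatenated
  // operand shapes, mirroring getOperandDimPositionInLoopsToShapeMap.
  int positionInShapeList(const std::vector<unsigned> &ranks,
                          unsigned operandIdx, unsigned dim) {
    unsigned pos = 0;
    for (unsigned i = 0; i < ranks.size(); ++i) {
      if (i == operandIdx)
        return pos + dim;
      pos += ranks[i];
    }
    return -1; // out of range; mirrors the Optional<unsigned> failure case
  }

  int main() {
    std::vector<unsigned> matmulRanks = {2, 2, 2}; // ranks of A, B, C
    printf("%d\n", positionInShapeList(matmulRanks, 1, 0)); // prints 2
    printf("%d\n", positionInShapeList(matmulRanks, 2, 1)); // prints 5
  }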
@@ -1027,6 +1077,12 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       }
       return res;
     }
+
+    /// Returns the value that expresses the shape of the output in terms of
+    /// the shape of the input operands, where possible.
+    Optional<Value> inferResultDimFromInputShapes
+        (OpBuilder &b, Location loc, unsigned resultIdx, unsigned dim);
+
     //========================================================================//
     // Helper functions to mutate the `operand_segment_sizes` attribute.
     // These are useful when cloning and changing operand types.
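For context, a caller-side sketch of how this new hook is meant to be used
(hypothetical usage, mirroring the ReplaceDimOfLinalgOpResult pattern added
later in this commit):

  // Try to replace `dim %result, <dim>` by a value computed from the
  // operand shapes; fails when the shape is not inferable from inputs.
  Optional<Value> v = linalgOp.inferResultDimFromInputShapes(
      rewriter, loc, /*resultIdx=*/0, /*dim=*/0);
  if (v)
    rewriter.replaceOp(dimOp, *v);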
@@ -9,6 +9,9 @@
 #ifndef MLIR_DIALECT_LINALG_LINALGTYPES_H_
 #define MLIR_DIALECT_LINALG_LINALGTYPES_H_

+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Types.h"
@@ -159,29 +159,29 @@ public:

   // Default visit methods. Note that the default op-specific binary op visit
   // methods call the general visitAffineBinaryOpExpr visit method.
-  void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {}
-  void visitAddExpr(AffineBinaryOpExpr expr) {
-    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  RetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { return RetTy(); }
+  RetTy visitAddExpr(AffineBinaryOpExpr expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitMulExpr(AffineBinaryOpExpr expr) {
-    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  RetTy visitMulExpr(AffineBinaryOpExpr expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitModExpr(AffineBinaryOpExpr expr) {
-    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  RetTy visitModExpr(AffineBinaryOpExpr expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitFloorDivExpr(AffineBinaryOpExpr expr) {
-    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  RetTy visitFloorDivExpr(AffineBinaryOpExpr expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitCeilDivExpr(AffineBinaryOpExpr expr) {
-    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  RetTy visitCeilDivExpr(AffineBinaryOpExpr expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitConstantExpr(AffineConstantExpr expr) {}
-  void visitDimExpr(AffineDimExpr expr) {}
-  void visitSymbolExpr(AffineSymbolExpr expr) {}
+  RetTy visitConstantExpr(AffineConstantExpr expr) { return RetTy(); }
+  RetTy visitDimExpr(AffineDimExpr expr) { return RetTy(); }
+  RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); }

 private:
   // Walk the operands - each operand is itself walked in post order.
-  void walkOperandsPostOrder(AffineBinaryOpExpr expr) {
+  RetTy walkOperandsPostOrder(AffineBinaryOpExpr expr) {
     walkPostOrder(expr.getLHS());
     walkPostOrder(expr.getRHS());
   }
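The change above is mechanical but enabling: with a RetTy template parameter,
a visitor subclass can return a value from each visit instead of accumulating
it in mutable state. A minimal standalone CRTP sketch of the idea (hypothetical
miniature types, not MLIR's actual classes):

  #include <cstdio>

  // Miniature of the AffineExprVisitor CRTP scheme: the second template
  // parameter is the return type of the visit methods (default void).
  template <typename SubClass, typename RetTy = void>
  struct VisitorBase {
    RetTy visitLeaf(int) { return RetTy(); }
    RetTy visitPair(int a, int b) {
      // Default: dispatch to the subclass and propagate its return value.
      return static_cast<SubClass *>(this)->visitLeaf(a);
    }
  };

  // A bool-returning subclass in the style of HasAffineDimExprVisitor,
  // which this commit adds in LinalgOps.cpp below.
  struct AnyNegative : VisitorBase<AnyNegative, bool> {
    bool visitLeaf(int v) { return v < 0; }
    bool visitPair(int a, int b) { return visitLeaf(a) || visitLeaf(b); }
  };

  int main() {
    AnyNegative v;
    printf("%d\n", v.visitPair(3, -1)); // prints 1
  }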
@@ -16,12 +16,14 @@
 #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"

 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
@@ -86,6 +88,82 @@ SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
   return res;
 }

+/// Visitor to check if any of the given set of positions from AffineDimExprs
+/// are used within an AffineExpr.
+struct HasAffineDimExprVisitor
+    : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
+  HasAffineDimExprVisitor(llvm::SmallSet<unsigned, 4> &positions)
+      : positions(positions) {}
+
+  bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) {
+    return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
+  }
+
+  bool visitDimExpr(AffineDimExpr dimExpr) {
+    return positions.count(dimExpr.getPosition());
+  }
+
+  bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }
+
+  bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }
+
+private:
+  llvm::SmallSet<unsigned, 4> positions;
+};
+
+Optional<Value> LinalgOp::inferResultDimFromInputShapes(OpBuilder &b,
+                                                        Location loc,
+                                                        unsigned resultIdx,
+                                                        unsigned dim) {
+  // An example that helps understand the logic below.
+  // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
+  // We want to express the shape of dim 0 of O in terms of shape of the inputs.
+  // This is achieved as follows.
+  //   loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
+  //   subMapOfResultDim = (d0, d1, d2) -> (d0 + d1)
+  //   shapesToLoopsMap = (d0, d1, d2, d3, d4, d5) -> (d0, d3, d2)
+  //   resultFromInputDim = subMapOfResultDim.compose(shapesToLoopsMap)
+  //                      = (d0, d1, d2, d3, d4, d5) -> (d0 + d1)
+  AffineMap loopsToShapesMap = getLoopsToShapesMap();
+
+  // Find the position in the above map that represents the shape of the
+  // result:dim being inferred.
+  Optional<unsigned> resultDimSubMapPos =
+      getResultValueDimPositionInLoopsToShapeMap(resultIdx, dim);
+  if (!resultDimSubMapPos)
+    return {};
+
+  /// From loopsToShapesMap extract the submap that represents the shape of the
+  /// (resultIdx, dim) needed.
+  AffineMap loopToResultDimShapeMap =
+      loopsToShapesMap.getSubMap(*resultDimSubMapPos);
+  AffineMap operandShapesToResultDimMap =
+      loopToResultDimShapeMap.compose(getShapesToLoopsMap());
+
+  // Check that the result dim map does not contain the positions corresponding
+  // to the outputs.
+  llvm::SmallSet<unsigned, 4> outputDims;
+  unsigned outputDimPosStart =
+      getResultValueDimPositionInLoopsToShapeMap(0, 0).getValue();
+  unsigned outputDimPosEnd =
+      getResultValueDimPositionInLoopsToShapeMap(getNumOutputs() - 1,
+                                                 getOutputOpOperands()
+                                                         .back()
+                                                         .get()
+                                                         .getType()
+                                                         .cast<ShapedType>()
+                                                         .getRank() -
+                                                     1)
+          .getValue();
+  llvm::for_each(llvm::seq<unsigned>(outputDimPosStart, outputDimPosEnd),
+                 [&outputDims](unsigned dim) { outputDims.insert(dim); });
+  HasAffineDimExprVisitor checkDimExpr(outputDims);
+  if (checkDimExpr.visit(operandShapesToResultDimMap.getResult(0)))
+    return llvm::None;
+  return applyMapToValues(b, loc, operandShapesToResultDimMap,
+                          createFlatListOfOperandDims(b, loc))[0];
+}
+
 /// Forward declarations.
 template <typename NamedStructuredOpType>
 static void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder,
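As a cross-check of the comment above, the same derivation worked for plain
linalg.matmul, O(i, j) += A(i, k) * B(k, j) (an illustration in the comment's
own notation, not part of the patch):

  loopsToShapesMap  = (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1)  // A0 A1 B0 B1 C0 C1
  subMapOfResultDim = (d0, d1, d2) -> (d0)                      // shape of O at dim 0
  shapesToLoopsMap  = (d0, d1, d2, d3, d4, d5) -> (d0, d3, d1)
  composition       = (d0, d1, d2, d3, d4, d5) -> (d0)

So dim 0 of the matmul result is exactly dim 0 of A, and the pattern below can
rewrite `dim %0, %c0` into `dim %arg0, %c0`.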
@@ -2022,6 +2100,49 @@ struct FoldTensorCastOp : public RewritePattern {
     return success();
   }
 };

+/// Replaces std.dim operations that use the result of a LinalgOp (on tensors)
+/// with std.dim operations that use one of the arguments. For example,
+///
+///   %0 = linalg.matmul ins(%arg0, %arg1, ...)
+///   %1 = dim %0, %c0
+///
+/// with
+///
+///   %1 = dim %arg0, %c0
+///
+/// where possible. With this, the result of the `linalg.matmul` is no longer
+/// used in dim operations. If the value produced is later replaced with
+/// another value (say by tiling `linalg.matmul`), the `linalg.matmul` becomes
+/// truly dead instead of staying alive only to feed a dim op, which would
+/// prevent its DCE.
+struct ReplaceDimOfLinalgOpResult : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    Value dimValue = dimOp.memrefOrTensor();
+    Optional<int64_t> dimIndex = dimOp.getConstantIndex();
+    if (!dimIndex)
+      return failure();
+    auto linalgOp = dimValue.getDefiningOp<LinalgOp>();
+    if (!linalgOp)
+      return failure();
+
+    unsigned resultIndex = dimValue.cast<OpResult>().getResultNumber();
+    Optional<Value> operandDimValue = linalgOp.inferResultDimFromInputShapes(
+        rewriter, dimOp.getLoc(), resultIndex,
+        static_cast<unsigned>(*dimIndex));
+    if (!operandDimValue) {
+      // It is always possible to replace using the corresponding `outs`
+      // parameter.
+      operandDimValue = rewriter.create<DimOp>(
+          dimOp.getLoc(), linalgOp.getOutput(resultIndex), *dimIndex);
+    }
+    rewriter.replaceOp(dimOp, *operandDimValue);
+    return success();
+  }
+};
+
 } // namespace

 namespace {
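When a result dimension's loop index appears only in the output indexing map,
input shapes cannot express it, inferResultDimFromInputShapes fails, and the
pattern falls back to the `outs` operand, whose shape always matches the
result. A sketch of that case (abridged IR, modeled on the
remove_dim_result_uses_outs test added below):

  %0 = linalg.init_tensor [%d0, %arg1] : tensor<?x?xf32>
  %1 = linalg.generic {...} ins(%arg0 : tensor<?xf32>)
         outs(%0 : tensor<?x?xf32>) {...} -> tensor<?x?xf32>
  %2 = dim %1, %c1 : tensor<?x?xf32>
  // becomes: %2 = dim %0, %c1, which then folds to %arg1 through the
  // init_tensor canonicalization.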
@@ -2166,26 +2287,6 @@ struct RemoveIdentityLinalgOps : public RewritePattern {
     return success();
   }
 };

-/// Canonicalize a `linalgOp` -> `dim` pattern by replacing the `dim` arg
-/// with the corresponding output tensor argument of the linalg op.
-struct ReplaceDimOfLinalgResult : public OpRewritePattern<DimOp> {
-  using OpRewritePattern<DimOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(DimOp dimOp,
-                                PatternRewriter &rewriter) const override {
-    Value dimOpArg = dimOp.memrefOrTensor();
-    auto linalgOp = dimOpArg.getDefiningOp<LinalgOp>();
-    if (!linalgOp)
-      return failure();
-
-    auto results = linalgOp.getOperation()->getResults();
-    int64_t id = std::distance(results.begin(), llvm::find(results, dimOpArg));
-    auto outputTensors = linalgOp.getOutputTensors();
-    rewriter.replaceOpWithNewOp<DimOp>(dimOp, outputTensors[id], dimOp.index());
-    return success();
-  }
-};
 } // namespace

 #define CANONICALIZERS_AND_FOLDERS(XXX)                                        \
@@ -2193,7 +2294,7 @@ struct ReplaceDimOfLinalgResult : public OpRewritePattern<DimOp> {
                                         MLIRContext *context) {               \
     results.insert<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp,    \
                    RemoveIdentityLinalgOps>();                                \
-    results.insert<ReplaceDimOfLinalgResult>(context);                        \
+    results.insert<ReplaceDimOfLinalgOpResult>(context);                      \
   }                                                                           \
                                                                               \
   LogicalResult XXX::fold(ArrayRef<Attribute>,                                \
@@ -58,9 +58,6 @@ struct LinalgInlinerInterface : public DialectInlinerInterface {
 //===----------------------------------------------------------------------===//

 void mlir::linalg::LinalgDialect::initialize() {
-  getContext()->getOrLoadDialect("std");
-  getContext()->getOrLoadDialect("tensor");
-
   addTypes<RangeType>();
   addOperations<
 #define GET_OP_LIST
@@ -390,10 +390,147 @@ func @init_tensor_dynamic_dim(%arg0 : index) -> (index) {

 // -----

+func @init_tensor_dynamic_dim2(%arg0 : index, %arg1 : index) -> (index, index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
+  %1 = dim %0, %c0 : tensor<?x?xf32>
+  %2 = dim %0, %c1 : tensor<?x?xf32>
+  return %1, %2 : index, index
+}
+// CHECK: func @init_tensor_dynamic_dim2
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: return %[[ARG0]], %[[ARG1]]
+
+// -----
+
+func @remove_dim_result_uses
+  (%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+   %arg2 : tensor<?x?xf32>) -> (index) {
+  %c0 = constant 0 : index
+  %0 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                      affine_map<(d0, d1, d2) -> (d2, d1)>,
+                      affine_map<(d0, d1, d2) -> (d0 + d1, d1)>],
+     iterator_types = ["parallel", "parallel", "reduction"]}
+    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%arg2 : tensor<?x?xf32>) {
+    ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+      %1 = mulf %arg3, %arg4 : f32
+      %2 = addf %1, %arg5 : f32
+      linalg.yield %2 : f32
+    } -> tensor<?x?xf32>
+  %3 = dim %0, %c0 : tensor<?x?xf32>
+  return %3 : index
+}
+// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK: func @remove_dim_result_uses
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[T0:.+]] = dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[T1:.+]] = dim %[[ARG1]], %[[C1]]
+// CHECK: %[[T2:.+]] = affine.apply #[[MAP]]()[%[[T0]], %[[T1]]]
+// CHECK: return %[[T2]]
+
+// -----
+
+func @remove_dim_result_uses_outs
+  (%arg0 : tensor<?xf32>, %arg1 : index) -> (index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %d0 = dim %arg0, %c0 : tensor<?xf32>
+  %0 = linalg.init_tensor [%d0, %arg1] : tensor<?x?xf32>
+  %1 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+                      affine_map<(d0, d1) -> (d0, d1)>],
+     iterator_types = ["parallel", "parallel"]}
+    ins(%arg0 : tensor<?xf32>) outs(%0 : tensor<?x?xf32>) {
+    ^bb0(%arg2: f32, %arg3: f32) :
+      linalg.yield %arg2 : f32
+    } -> tensor<?x?xf32>
+  %2 = dim %1, %c1 : tensor<?x?xf32>
+  return %2 : index
+}
+// CHECK: func @remove_dim_result_uses_outs
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: return %[[ARG1]]
+
+// -----
+
+func @remove_dim_result_uses_sequence
+  (%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+   %arg2 : tensor<?x?xf32>) -> (index, index, index, index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %1 = dim %0, %c0 : tensor<?x?xf32>
+  %2 = dim %0, %c1 : tensor<?x?xf32>
+  %3 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0)>,
+                      affine_map<(d0, d1, d2) -> (d0, d2)>,
+                      affine_map<(d0, d1, d2) -> (d0, d2)>],
+     iterator_types = ["parallel", "reduction", "parallel"]}
+    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%0 : tensor<?x?xf32>) {
+    ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+      %4 = mulf %arg3, %arg4 : f32
+      %5 = addf %4, %arg5 : f32
+      linalg.yield %5 : f32
+    } -> tensor<?x?xf32>
+  %6 = dim %3, %c0 : tensor<?x?xf32>
+  %7 = dim %3, %c1 : tensor<?x?xf32>
+  return %1, %2, %6, %7 : index, index, index, index
+}
+// CHECK-LABEL: func @remove_dim_result_uses_sequence
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[T0:.+]] = dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[T1:.+]] = dim %[[ARG1]], %[[C1]]
+// CHECK-DAG: %[[T2:.+]] = dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[T3:.+]] = dim %[[ARG1]], %[[C1]]
+// CHECK: return %[[T0]], %[[T1]], %[[T2]], %[[T3]]
+
+// -----
+
+func @keep_result_dim_uses_sequence2
+  (%arg0 : tensor<?xf32>, %arg1 : index) -> (index, index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %d0 = dim %arg0, %c0 : tensor<?xf32>
+  %0 = linalg.init_tensor [%d0, %arg1] : tensor<?x?xf32>
+  %1 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+                      affine_map<(d0, d1) -> (d0, d1)>],
+     iterator_types = ["parallel", "parallel"]}
+    ins(%arg0 : tensor<?xf32>) outs(%0 : tensor<?x?xf32>) {
+    ^bb0(%arg2: f32, %arg3 : f32):
+      linalg.yield %arg2 : f32
+    } -> tensor<?x?xf32>
+  %2 = dim %1, %c0 : tensor<?x?xf32>
+  %3 = dim %1, %c1 : tensor<?x?xf32>
+  return %2, %3 : index, index
+}
+// CHECK: func @keep_result_dim_uses_sequence2
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[T0:.+]] = dim %[[ARG0]], %[[C0]]
+// CHECK: return %[[T0]], %[[ARG1]]
+
+// -----
+
 #map = affine_map<(d0) -> (d0)>

 func @init_tensor_dim_of_linalg_result(%arg_0 : tensor<?xf32>,
-    %arg_1: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+    %arg_1: tensor<?xf32>) -> (index, index) {
   %0, %1 = linalg.generic {
       indexing_maps = [#map, #map, #map],
       iterator_types = ["parallel"]
@@ -405,16 +542,16 @@ func @init_tensor_dim_of_linalg_result(%arg_0 : tensor<?xf32>,

   %c0 = constant 0 : index
   %num_elem_0 = dim %0, %c0 : tensor<?xf32>
-  %result_0 = linalg.init_tensor [%num_elem_0] : tensor<?xf32>

   %num_elem_1 = dim %1, %c0 : tensor<?xf32>
-  %result_1 = linalg.init_tensor [%num_elem_1] : tensor<?xf32>
-  return %result_0, %result_1 : tensor<?xf32>, tensor<?xf32>
+  return %num_elem_0, %num_elem_1 : index, index
 }
-// CHECK-LABEL: func @init_tensor_dim_of_linalg_result(
-// CHECK-SAME: [[ARG_0:%.*]]: tensor<?xf32>, [[ARG_1:%.*]]: tensor<?xf32>)
-// CHECK: dim [[ARG_0]]
-// CHECK: dim [[ARG_1]]
+// CHECK: func @init_tensor_dim_of_linalg_result(
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?xf32>
+// CHECK-SAME: %[[ARG_1:[a-zA-Z0-9_]+]]: tensor<?xf32>)
+// CHECK: %[[R0:.+]] = dim %[[ARG_0]]
+// CHECK: %[[R1:.+]] = dim %[[ARG_0]]
+// CHECK: return %[[R0]], %[[R1]]

 // -----