forked from OSchip/llvm-project
[mlir][linalg] Remove IndexedGenericOp support from LinalgToStandard...
after introducing the IndexedGenericOp to GenericOp canonicalization (https://reviews.llvm.org/D101612). Differential Revision: https://reviews.llvm.org/D102236
This commit is contained in:
parent
96100f1508
commit
0fb364a97e
|
@ -28,8 +28,8 @@ namespace linalg {
|
|||
// Create a new call to the type-canonicalized `LinalgOp::getLibraryCallName()`
|
||||
// function. The implementation of the function can be either in the same module
|
||||
// or in an externally linked library.
|
||||
// This is a generic entry point for all LinalgOp, except for CopyOp and
|
||||
// IndexedGenericOp, for which more specialized patterns are provided.
|
||||
// This is a generic entry point for all LinalgOp, except for CopyOp, for which
|
||||
// more specialized patterns are provided.
|
||||
class LinalgOpToLibraryCallRewrite
|
||||
: public OpInterfaceRewritePattern<LinalgOp> {
|
||||
public:
|
||||
|
@ -58,16 +58,6 @@ public:
|
|||
PatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
/// Conversion pattern specialization for IndexedGenericOp, has special handling
|
||||
/// for the extra index operands.
|
||||
class IndexedGenericOpToLibraryCallRewrite
|
||||
: public OpRewritePattern<IndexedGenericOp> {
|
||||
public:
|
||||
using OpRewritePattern<IndexedGenericOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(IndexedGenericOp op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
/// Populate the given list with patterns that convert from Linalg to Standard.
|
||||
void populateLinalgToStandardConversionPatterns(RewritePatternSet &patterns);
|
||||
|
||||
|
|
|
@ -26,12 +26,6 @@ using namespace mlir::linalg;
|
|||
static SmallVector<Type, 4> extractOperandTypes(Operation *op) {
|
||||
SmallVector<Type, 4> result;
|
||||
result.reserve(op->getNumOperands());
|
||||
if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op)) {
|
||||
auto *ctx = op->getContext();
|
||||
auto numLoops = indexedGenericOp.getNumLoops();
|
||||
result.reserve(op->getNumOperands() + numLoops);
|
||||
result.assign(numLoops, IndexType::get(ctx));
|
||||
}
|
||||
for (auto type : op->getOperandTypes()) {
|
||||
// The underlying descriptor type (e.g. LLVM) does not have layout
|
||||
// information. Canonicalizing the type at the level of std when going into
|
||||
|
@ -103,7 +97,11 @@ createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc,
|
|||
LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
|
||||
LinalgOp op, PatternRewriter &rewriter) const {
|
||||
// Only LinalgOp for which there is no specialized pattern go through this.
|
||||
if (isa<CopyOp>(op) || isa<IndexedGenericOp>(op))
|
||||
if (isa<CopyOp>(op))
|
||||
return failure();
|
||||
|
||||
// Canonicalize indexed generic operations before library call conversion.
|
||||
if (isa<IndexedGenericOp>(op))
|
||||
return failure();
|
||||
|
||||
auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
|
||||
|
@ -167,31 +165,6 @@ LogicalResult mlir::linalg::CopyTransposeRewrite::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
mlir::linalg::IndexedGenericOpToLibraryCallRewrite::matchAndRewrite(
|
||||
IndexedGenericOp op, PatternRewriter &rewriter) const {
|
||||
auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
|
||||
if (!libraryCallName)
|
||||
return failure();
|
||||
|
||||
// TODO: Use induction variables values instead of zeros, when
|
||||
// IndexedGenericOp is tiled.
|
||||
auto zero = rewriter.create<mlir::ConstantOp>(
|
||||
op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
|
||||
auto indexedGenericOp = cast<IndexedGenericOp>(op);
|
||||
auto numLoops = indexedGenericOp.getNumLoops();
|
||||
SmallVector<Value, 4> operands;
|
||||
operands.reserve(numLoops + op.getNumOperands());
|
||||
for (unsigned i = 0; i < numLoops; ++i)
|
||||
operands.push_back(zero);
|
||||
for (auto operand : op.getOperands())
|
||||
operands.push_back(operand);
|
||||
rewriter.replaceOpWithNewOp<mlir::CallOp>(
|
||||
op, libraryCallName.getValue(), TypeRange(),
|
||||
createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), operands));
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Populate the given list with patterns that convert from Linalg to Standard.
|
||||
void mlir::linalg::populateLinalgToStandardConversionPatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
|
@ -201,7 +174,6 @@ void mlir::linalg::populateLinalgToStandardConversionPatterns(
|
|||
patterns.add<
|
||||
CopyOpToLibraryCallRewrite,
|
||||
CopyTransposeRewrite,
|
||||
IndexedGenericOpToLibraryCallRewrite,
|
||||
LinalgOpToLibraryCallRewrite>(patterns.getContext());
|
||||
// clang-format on
|
||||
}
|
||||
|
|
|
@ -95,25 +95,3 @@ func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C
|
|||
}
|
||||
// CHECK-LABEL: func @matmul_vec_impl(
|
||||
// CHECK: call @external_outerproduct_matmul(%{{.*}}) :
|
||||
|
||||
#indexed_matmul_trait = {
|
||||
iterator_types = ["parallel", "parallel", "reduction"],
|
||||
indexing_maps = #matmul_accesses,
|
||||
library_call = "external_indexed_outerproduct_matmul"
|
||||
}
|
||||
func @matmul_vec_indexed(%A: !matrix_type_A,
|
||||
%B: !matrix_type_B,
|
||||
%C: !matrix_type_C) {
|
||||
linalg.indexed_generic #indexed_matmul_trait
|
||||
ins(%A, %B : !matrix_type_A, !matrix_type_B)
|
||||
outs(%C : !matrix_type_C) {
|
||||
^bb0(%i: index, %j: index, %k: index,
|
||||
%a: !vector_type_A, %b: !vector_type_B, %c: !vector_type_C):
|
||||
%d = vector.outerproduct %a, %b, %c: !vector_type_A, !vector_type_B
|
||||
linalg.yield %d: !vector_type_C
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @matmul_vec_indexed(
|
||||
// CHECK: %[[ZERO:.*]] = constant 0 : index
|
||||
// CHECK: call @external_indexed_outerproduct_matmul(%[[ZERO]], %[[ZERO]], %[[ZERO]], %{{.*}})
|
||||
|
|
Loading…
Reference in New Issue