2020-05-15 12:22:21 +08:00
|
|
|
//===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===//
|
|
|
|
//
|
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
|
|
|
|
|
|
|
|
#include "../PassDetail.h"
|
|
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
|
|
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
|
|
|
#include "mlir/Dialect/SCF/SCF.h"
|
|
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
using namespace mlir::linalg;
|
|
|
|
|
|
|
|
/// Helper function to extract the operand types that are passed to the
|
|
|
|
/// generated CallOp. MemRefTypes have their layout canonicalized since the
|
|
|
|
/// information is not used in signature generation.
|
|
|
|
/// Note that static size information is not modified.
|
|
|
|
template <typename LinalgOp>
|
|
|
|
static SmallVector<Type, 4> extractOperandTypes(Operation *op) {
|
|
|
|
SmallVector<Type, 4> result;
|
|
|
|
result.reserve(op->getNumOperands());
|
|
|
|
for (auto type : op->getOperandTypes()) {
|
|
|
|
// The underlying descriptor type (e.g. LLVM) does not have layout
|
|
|
|
// information. Canonicalizing the type at the level of std when going into
|
|
|
|
// a library call avoids needing to introduce DialectCastOp.
|
|
|
|
if (auto memrefType = type.dyn_cast<MemRefType>())
|
|
|
|
result.push_back(eraseStridedLayout(memrefType));
|
|
|
|
else
|
|
|
|
result.push_back(type);
|
|
|
|
}
|
|
|
|
return result;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
SmallVector<Type, 4> extractOperandTypes<IndexedGenericOp>(Operation *op) {
|
|
|
|
auto *ctx = op->getContext();
|
|
|
|
auto indexedGenericOp = cast<IndexedGenericOp>(op);
|
|
|
|
auto numLoops = indexedGenericOp.getNumLoops();
|
|
|
|
|
|
|
|
SmallVector<Type, 4> result(numLoops, IndexType::get(ctx));
|
|
|
|
auto canonicalizedOperands = extractOperandTypes<LinalgOp>(op);
|
|
|
|
result.append(canonicalizedOperands.begin(), canonicalizedOperands.end());
|
|
|
|
return result;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Get a SymbolRefAttr containing the library function name for the LinalgOp.
|
|
|
|
// If the library function does not exist, insert a declaration.
|
|
|
|
template <typename LinalgOp>
|
|
|
|
static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
|
|
|
|
PatternRewriter &rewriter) {
|
|
|
|
auto linalgOp = cast<LinalgOp>(op);
|
|
|
|
auto fnName = linalgOp.getLibraryCallName();
|
|
|
|
if (fnName.empty()) {
|
|
|
|
op->emitWarning("No library call defined for: ") << *op;
|
|
|
|
return {};
|
|
|
|
}
|
|
|
|
|
|
|
|
// fnName is a dynamic std::string, unique it via a SymbolRefAttr.
|
|
|
|
FlatSymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName);
|
|
|
|
auto module = op->getParentOfType<ModuleOp>();
|
|
|
|
if (module.lookupSymbol(fnName)) {
|
|
|
|
return fnNameAttr;
|
|
|
|
}
|
|
|
|
|
|
|
|
SmallVector<Type, 4> inputTypes(extractOperandTypes<LinalgOp>(op));
|
|
|
|
assert(op->getNumResults() == 0 &&
|
|
|
|
"Library call for linalg operation can be generated only for ops that "
|
|
|
|
"have void return types");
|
|
|
|
auto libFnType = FunctionType::get(inputTypes, {}, rewriter.getContext());
|
|
|
|
|
|
|
|
OpBuilder::InsertionGuard guard(rewriter);
|
|
|
|
// Insert before module terminator.
|
|
|
|
rewriter.setInsertionPoint(module.getBody(),
|
|
|
|
std::prev(module.getBody()->end()));
|
|
|
|
FuncOp funcOp =
|
2020-07-08 07:15:44 +08:00
|
|
|
rewriter.create<FuncOp>(op->getLoc(), fnNameAttr.getValue(), libFnType);
|
2020-05-15 12:22:21 +08:00
|
|
|
// Insert a function attribute that will trigger the emission of the
|
|
|
|
// corresponding `_mlir_ciface_xxx` interface so that external libraries see
|
|
|
|
// a normalized ABI. This interface is added during std to llvm conversion.
|
|
|
|
funcOp.setAttr("llvm.emit_c_interface", UnitAttr::get(op->getContext()));
|
|
|
|
return fnNameAttr;
|
|
|
|
}
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
|
|
|
/// Builds the operand list for a library call. Every memref operand is routed
/// through a MemRefCastOp to its layout-erased type, so it matches the
/// signature produced by extractOperandTypes. The cast is created even when
/// the type already has no layout. Non-memref operands are forwarded as-is.
/// Casts are created with `b` at its current insertion point, using `loc`.
SmallVector<Value, 4>
createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc,
                                      ValueRange operands) {
  SmallVector<Value, 4> canonicalized;
  canonicalized.reserve(operands.size());
  for (Value operand : operands) {
    if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
      Value erased =
          b.create<MemRefCastOp>(loc, eraseStridedLayout(memrefType), operand);
      canonicalized.push_back(erased);
    } else {
      canonicalized.push_back(operand);
    }
  }
  return canonicalized;
}
|
|
|
|
|
|
|
|
// LinalgOpConversion<LinalgOp> creates a new call to the type-canonicalized
|
|
|
|
// `LinalgOp::getLibraryCallName()` function.
|
|
|
|
// The implementation of the function can be either in the same module or in an
|
|
|
|
// externally linked library.
|
|
|
|
template <typename LinalgOp>
|
|
|
|
class LinalgOpConversion : public OpRewritePattern<LinalgOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern<LinalgOp>::OpRewritePattern;
|
|
|
|
|
|
|
|
LogicalResult matchAndRewrite(LinalgOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
auto libraryCallName = getLibraryCallSymbolRef<LinalgOp>(op, rewriter);
|
|
|
|
if (!libraryCallName)
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<mlir::CallOp>(
|
2020-09-23 12:00:11 +08:00
|
|
|
op, libraryCallName.getValue(), TypeRange(),
|
2020-05-15 12:22:21 +08:00
|
|
|
createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(),
|
|
|
|
op.getOperands()));
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
/// Conversion pattern specialization for CopyOp. This kicks in when both input
|
|
|
|
/// and output permutations are left unspecified or are the identity.
|
|
|
|
template <>
|
|
|
|
class LinalgOpConversion<CopyOp> : public OpRewritePattern<CopyOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern<CopyOp>::OpRewritePattern;
|
|
|
|
|
|
|
|
LogicalResult matchAndRewrite(CopyOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
auto inputPerm = op.inputPermutation();
|
|
|
|
if (inputPerm.hasValue() && !inputPerm->isIdentity())
|
|
|
|
return failure();
|
|
|
|
auto outputPerm = op.outputPermutation();
|
|
|
|
if (outputPerm.hasValue() && !outputPerm->isIdentity())
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
auto libraryCallName = getLibraryCallSymbolRef<CopyOp>(op, rewriter);
|
|
|
|
if (!libraryCallName)
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<mlir::CallOp>(
|
2020-09-23 12:00:11 +08:00
|
|
|
op, libraryCallName.getValue(), TypeRange(),
|
2020-05-15 12:22:21 +08:00
|
|
|
createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(),
|
|
|
|
op.getOperands()));
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
/// Conversion pattern specialization for IndexedGenericOp.
|
|
|
|
template <>
|
|
|
|
class LinalgOpConversion<IndexedGenericOp>
|
|
|
|
: public OpRewritePattern<IndexedGenericOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern<IndexedGenericOp>::OpRewritePattern;
|
|
|
|
|
|
|
|
LogicalResult matchAndRewrite(IndexedGenericOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
auto libraryCallName =
|
|
|
|
getLibraryCallSymbolRef<IndexedGenericOp>(op, rewriter);
|
|
|
|
if (!libraryCallName)
|
|
|
|
return failure();
|
|
|
|
|
2020-07-07 16:35:23 +08:00
|
|
|
// TODO: Use induction variables values instead of zeros, when
|
2020-05-15 12:22:21 +08:00
|
|
|
// IndexedGenericOp is tiled.
|
|
|
|
auto zero = rewriter.create<mlir::ConstantOp>(
|
|
|
|
op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
|
|
|
|
auto indexedGenericOp = cast<IndexedGenericOp>(op);
|
|
|
|
auto numLoops = indexedGenericOp.getNumLoops();
|
|
|
|
SmallVector<Value, 4> operands;
|
|
|
|
operands.reserve(numLoops + op.getNumOperands());
|
|
|
|
for (unsigned i = 0; i < numLoops; ++i)
|
|
|
|
operands.push_back(zero);
|
|
|
|
for (auto operand : op.getOperands())
|
|
|
|
operands.push_back(operand);
|
|
|
|
rewriter.replaceOpWithNewOp<mlir::CallOp>(
|
2020-09-23 12:00:11 +08:00
|
|
|
op, libraryCallName.getValue(), TypeRange(),
|
2020-05-15 12:22:21 +08:00
|
|
|
createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), operands));
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
/// A non-conversion rewrite pattern kicks in to convert CopyOp with
|
|
|
|
/// permutations into a sequence of TransposeOp and permutation-free CopyOp.
|
|
|
|
/// This interplays together with TransposeOpConversion and
|
|
|
|
/// LinalgConversion<CopyOp> to create a path to the LLVM dialect.
|
|
|
|
class CopyTransposeConversion : public OpRewritePattern<CopyOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern<CopyOp>::OpRewritePattern;
|
|
|
|
|
|
|
|
LogicalResult matchAndRewrite(CopyOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Value in = op.input(), out = op.output();
|
|
|
|
|
|
|
|
// If either inputPerm or outputPerm are non-identities, insert transposes.
|
|
|
|
auto inputPerm = op.inputPermutation();
|
|
|
|
if (inputPerm.hasValue() && !inputPerm->isIdentity())
|
|
|
|
in = rewriter.create<linalg::TransposeOp>(op.getLoc(), in,
|
|
|
|
AffineMapAttr::get(*inputPerm));
|
|
|
|
auto outputPerm = op.outputPermutation();
|
|
|
|
if (outputPerm.hasValue() && !outputPerm->isIdentity())
|
|
|
|
out = rewriter.create<linalg::TransposeOp>(
|
|
|
|
op.getLoc(), out, AffineMapAttr::get(*outputPerm));
|
|
|
|
|
|
|
|
// If nothing was transposed, fail and let the conversion kick in.
|
|
|
|
if (in == op.input() && out == op.output())
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<CopyOp>(op, in, out);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
/// Populate the given list with patterns that convert from Linalg to Standard.
|
|
|
|
void mlir::populateLinalgToStandardConversionPatterns(
|
|
|
|
OwningRewritePatternList &patterns, MLIRContext *ctx) {
|
2020-07-07 16:35:23 +08:00
|
|
|
// TODO: ConvOp conversion needs to export a descriptor with relevant
|
2020-05-15 12:22:21 +08:00
|
|
|
// attribute values such as kernel striding and dilation.
|
|
|
|
// clang-format off
|
|
|
|
patterns.insert<
|
|
|
|
CopyTransposeConversion,
|
|
|
|
LinalgOpConversion<ConvOp>,
|
|
|
|
LinalgOpConversion<PoolingMaxOp>,
|
|
|
|
LinalgOpConversion<PoolingMinOp>,
|
|
|
|
LinalgOpConversion<PoolingSumOp>,
|
2020-07-29 22:24:48 +08:00
|
|
|
LinalgOpConversion<CopyOp>,
|
2020-05-15 12:22:21 +08:00
|
|
|
LinalgOpConversion<FillOp>,
|
|
|
|
LinalgOpConversion<GenericOp>,
|
2020-06-18 17:32:09 +08:00
|
|
|
LinalgOpConversion<IndexedGenericOp>>(ctx);
|
[mlir][Linalg] Retire C++ MatmulOp in favor of a linalg-ods-gen'd op.
Summary:
This revision replaces MatmulOp, now that DRR rules have been dropped.
This revision also fixes minor parsing bugs and a plugs a few holes to get e2e paths working (e.g. library call emission).
During the replacement the i32 version had to be dropped because only the EDSC operators +, *, etc support type inference.
Deciding on a type-polymorphic behavior, and implementing it, is left for future work.
Reviewers: aartbik
Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, stephenneuendorffer, Joonsoo, grosul1, frgossen, Kayjukh, jurahul, msifontes
Tags: #mlir
Differential Revision: https://reviews.llvm.org/D81935
2020-06-16 21:14:42 +08:00
|
|
|
// TODO: collect all auto-generated named ops with a tblgen directive.
|
|
|
|
patterns.insert<
|
2020-07-28 18:29:54 +08:00
|
|
|
LinalgOpConversion<DotOp>,
|
[mlir][Linalg] Retire C++ MatmulOp in favor of a linalg-ods-gen'd op.
Summary:
This revision replaces MatmulOp, now that DRR rules have been dropped.
This revision also fixes minor parsing bugs and a plugs a few holes to get e2e paths working (e.g. library call emission).
During the replacement the i32 version had to be dropped because only the EDSC operators +, *, etc support type inference.
Deciding on a type-polymorphic behavior, and implementing it, is left for future work.
Reviewers: aartbik
Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, stephenneuendorffer, Joonsoo, grosul1, frgossen, Kayjukh, jurahul, msifontes
Tags: #mlir
Differential Revision: https://reviews.llvm.org/D81935
2020-06-16 21:14:42 +08:00
|
|
|
LinalgOpConversion<BatchMatmulOp>,
|
2020-06-18 17:32:09 +08:00
|
|
|
LinalgOpConversion<MatvecOp>,
|
2020-09-11 00:48:13 +08:00
|
|
|
LinalgOpConversion<VecmatOp>,
|
2020-08-03 15:57:06 +08:00
|
|
|
LinalgOpConversion<MatmulOp>,
|
|
|
|
LinalgOpConversion<ConvWOp>,
|
|
|
|
LinalgOpConversion<ConvNWCOp>,
|
|
|
|
LinalgOpConversion<ConvNCWOp>,
|
|
|
|
LinalgOpConversion<ConvHWOp>,
|
|
|
|
LinalgOpConversion<ConvNHWCOp>,
|
|
|
|
LinalgOpConversion<ConvNCHWOp>,
|
|
|
|
LinalgOpConversion<ConvDHWOp>,
|
|
|
|
LinalgOpConversion<ConvNDHWCOp>,
|
|
|
|
LinalgOpConversion<ConvNCDHWOp>>(ctx);
|
2020-05-15 12:22:21 +08:00
|
|
|
// clang-format on
|
|
|
|
}
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
struct ConvertLinalgToStandardPass
|
|
|
|
: public ConvertLinalgToStandardBase<ConvertLinalgToStandardPass> {
|
|
|
|
void runOnOperation() override;
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
void ConvertLinalgToStandardPass::runOnOperation() {
|
|
|
|
auto module = getOperation();
|
|
|
|
ConversionTarget target(getContext());
|
|
|
|
target.addLegalDialect<AffineDialect, scf::SCFDialect, StandardOpsDialect>();
|
|
|
|
target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ReturnOp>();
|
|
|
|
target.addLegalOp<linalg::TransposeOp, linalg::ReshapeOp, linalg::RangeOp>();
|
|
|
|
OwningRewritePatternList patterns;
|
|
|
|
populateLinalgToStandardConversionPatterns(patterns, &getContext());
|
|
|
|
if (failed(applyFullConversion(module, target, patterns)))
|
|
|
|
signalPassFailure();
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Creates an instance of the Linalg-to-Standard conversion pass; ownership
/// is transferred to the caller.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertLinalgToStandardPass() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<ConvertLinalgToStandardPass>();
  return pass;
}
|