[mlir] Remove OperandAdaptor

Use ::Adaptor alias instead uniformly. Makes the naming more consistent as
adaptor can refer to attributes now too.

Differential Revision: https://reviews.llvm.org/D81789
This commit is contained in:
Jacques Pienaar 2020-06-15 06:01:31 -07:00
parent dae9554b2b
commit 2d2c73c5cf
21 changed files with 89 additions and 109 deletions

View File

@ -848,9 +848,8 @@ to access them. For example, for a binary arithmetic operation, it may provide
`.lhs()` to access the first operand and `.rhs()` to access the second operand.
The operand adaptor class lives in the same namespace as the operation class,
and has the name of the operation followed by `OperandAdaptor`. A template
declaration `OperandAdaptor<>` is provided to look up the operand adaptor for
the given operation.
and has the name of the operation followed by `Adaptor` as well as an alias
`Adaptor` inside the op class.
Operand adaptors can be used in function templates that also process operations:
@ -862,7 +861,7 @@ std::pair<Value, Value> zip(BinaryOpTy &&op) {
void process(AddOp op, ArrayRef<Value> newOperands) {
zip(op);
zip(OperandAdaptor<AddOp>(newOperands));
zip(Adaptor<AddOp>(newOperands));
/*...*/
}
```

View File

@ -124,7 +124,7 @@ struct TransposeOpLowering : public mlir::ConversionPattern {
// This allows for using the nice named accessors that are generated
// by the ODS. This adaptor is automatically provided by the ODS
// framework.
TransposeOpOperandAdaptor transposeAdaptor(memRefOperands);
TransposeOpAdaptor transposeAdaptor(memRefOperands);
mlir::Value input = transposeAdaptor.input();
// Transpose the elements by generating a load from the reverse

View File

@ -110,7 +110,7 @@ struct BinaryOpLowering : public ConversionPattern {
// Generate an adaptor for the remapped operands of the BinaryOp. This
// allows for using the nice named accessors that are generated by the
// ODS.
typename BinaryOp::OperandAdaptor binaryAdaptor(memRefOperands);
typename BinaryOp::Adaptor binaryAdaptor(memRefOperands);
// Generate loads for the element of 'lhs' and 'rhs' at the inner
// loop.
@ -234,7 +234,7 @@ struct TransposeOpLowering : public ConversionPattern {
// Generate an adaptor for the remapped operands of the TransposeOp.
// This allows for using the nice named accessors that are generated
// by the ODS.
toy::TransposeOpOperandAdaptor transposeAdaptor(memRefOperands);
toy::TransposeOpAdaptor transposeAdaptor(memRefOperands);
Value input = transposeAdaptor.input();
// Transpose the elements by generating a load from the reverse

View File

@ -110,7 +110,7 @@ struct BinaryOpLowering : public ConversionPattern {
// Generate an adaptor for the remapped operands of the BinaryOp. This
// allows for using the nice named accessors that are generated by the
// ODS.
typename BinaryOp::OperandAdaptor binaryAdaptor(memRefOperands);
typename BinaryOp::Adaptor binaryAdaptor(memRefOperands);
// Generate loads for the element of 'lhs' and 'rhs' at the inner
// loop.
@ -233,7 +233,7 @@ struct TransposeOpLowering : public ConversionPattern {
// Generate an adaptor for the remapped operands of the TransposeOp.
// This allows for using the nice named accessors that are generated
// by the ODS.
toy::TransposeOpOperandAdaptor transposeAdaptor(memRefOperands);
toy::TransposeOpAdaptor transposeAdaptor(memRefOperands);
Value input = transposeAdaptor.input();
// Transpose the elements by generating a load from the reverse

View File

@ -110,7 +110,7 @@ struct BinaryOpLowering : public ConversionPattern {
// Generate an adaptor for the remapped operands of the BinaryOp. This
// allows for using the nice named accessors that are generated by the
// ODS.
typename BinaryOp::OperandAdaptor binaryAdaptor(memRefOperands);
typename BinaryOp::Adaptor binaryAdaptor(memRefOperands);
// Generate loads for the element of 'lhs' and 'rhs' at the inner
// loop.
@ -234,7 +234,7 @@ struct TransposeOpLowering : public ConversionPattern {
// Generate an adaptor for the remapped operands of the TransposeOp.
// This allows for using the nice named accessors that are generated
// by the ODS.
toy::TransposeOpOperandAdaptor transposeAdaptor(memRefOperands);
toy::TransposeOpAdaptor transposeAdaptor(memRefOperands);
Value input = transposeAdaptor.input();
// Transpose the elements by generating a load from the reverse

View File

@ -47,14 +47,6 @@ class Value;
class ValueRange;
template <typename ValueRangeT> class ValueTypeRange;
/// This is an adaptor from a list of values to named operands of OpTy. In a
/// generic operation context, e.g., in dialect conversions, an ordered array of
/// `Value`s is treated as operands of `OpTy`. This adaptor takes a reference
/// to the array and provides accessors with the same names as `OpTy` for
/// operands. This makes possible to create function templates that operate on
/// either OpTy or OperandAdaptor<OpTy> seamlessly.
template <typename OpTy> using OperandAdaptor = typename OpTy::OperandAdaptor;
class OwningRewritePatternList;
//===----------------------------------------------------------------------===//

View File

@ -56,7 +56,7 @@ struct GPUShuffleOpLowering : public ConvertToLLVMPattern {
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
gpu::ShuffleOpOperandAdaptor adaptor(operands);
gpu::ShuffleOpAdaptor adaptor(operands);
auto dialect = typeConverter.getDialect();
auto valueTy = adaptor.value().getType().cast<LLVM::LLVMType>();

View File

@ -140,7 +140,7 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
// latch and the merge block the exit block. The resulting spirv::LoopOp has a
// single back edge from the continue to header block, and a single exit from
// header to merge.
scf::ForOpOperandAdaptor forOperands(operands);
scf::ForOpAdaptor forOperands(operands);
auto loc = forOp.getLoc();
auto loopControl = rewriter.getI32IntegerAttr(
static_cast<uint32_t>(spirv::LoopControl::None));
@ -211,7 +211,7 @@ IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
// When lowering `scf::IfOp` we explicitly create a selection header block
// before the control flow diverges and a merge block where control flow
// subsequently converges.
scf::IfOpOperandAdaptor ifOperands(operands);
scf::IfOpAdaptor ifOperands(operands);
auto loc = ifOp.getLoc();
// Create `spv.selection` operation, selection header block and merge block.

View File

@ -140,7 +140,7 @@ public:
edsc::ScopedContext context(rewriter, op->getLoc());
// Fill in an aggregate value of the descriptor.
RangeOpOperandAdaptor adaptor(operands);
RangeOpAdaptor adaptor(operands);
Value desc = llvm_undef(rangeDescriptorTy);
desc = llvm_insertvalue(desc, adaptor.min(), rewriter.getI64ArrayAttr(0));
desc = llvm_insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1));
@ -178,7 +178,7 @@ public:
return failure();
edsc::ScopedContext context(rewriter, op->getLoc());
ReshapeOpOperandAdaptor adaptor(operands);
ReshapeOpAdaptor adaptor(operands);
BaseViewConversionHelper baseDesc(adaptor.src());
BaseViewConversionHelper desc(typeConverter.convertType(dstType));
desc.setAllocatedPtr(baseDesc.allocatedPtr());
@ -208,7 +208,7 @@ public:
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
edsc::ScopedContext context(rewriter, op->getLoc());
SliceOpOperandAdaptor adaptor(operands);
SliceOpAdaptor adaptor(operands);
BaseViewConversionHelper baseDesc(adaptor.view());
auto sliceOp = cast<SliceOp>(op);
@ -302,7 +302,7 @@ public:
ConversionPatternRewriter &rewriter) const override {
// Initialize the common boilerplate and alloca at the top of the FuncOp.
edsc::ScopedContext context(rewriter, op->getLoc());
TransposeOpOperandAdaptor adaptor(operands);
TransposeOpAdaptor adaptor(operands);
BaseViewConversionHelper baseDesc(adaptor.view());
auto transposeOp = cast<TransposeOp>(op);

View File

@ -28,7 +28,7 @@ public:
LogicalResult
matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
typename SrcOpTy::OperandAdaptor adaptor(operands);
typename SrcOpTy::Adaptor adaptor(operands);
rewriter.replaceOpWithNewOp<DstOpTy>(op.getOperation(), adaptor.lhs(),
adaptor.rhs());
return success();
@ -43,7 +43,7 @@ public:
LogicalResult
matchAndRewrite(FromExtentTensorOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
FromExtentTensorOpOperandAdaptor transformed(operands);
FromExtentTensorOp::Adaptor transformed(operands);
rewriter.replaceOp(op.getOperation(), transformed.input());
return success();
}
@ -56,7 +56,7 @@ public:
LogicalResult
matchAndRewrite(IndexToSizeOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
IndexToSizeOpOperandAdaptor transformed(operands);
IndexToSizeOp::Adaptor transformed(operands);
rewriter.replaceOp(op.getOperation(), transformed.arg());
return success();
}
@ -69,7 +69,7 @@ public:
LogicalResult
matchAndRewrite(SizeToIndexOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
SizeToIndexOpOperandAdaptor transformed(operands);
SizeToIndexOp::Adaptor transformed(operands);
rewriter.replaceOp(op.getOperation(), transformed.arg());
return success();
}
@ -83,7 +83,7 @@ public:
LogicalResult
matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
ToExtentTensorOpOperandAdaptor transformed(operands);
ToExtentTensorOp::Adaptor transformed(operands);
rewriter.replaceOp(op.getOperation(), transformed.input());
return success();
}

View File

@ -1336,7 +1336,7 @@ struct CreateComplexOpLowering
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto complexOp = cast<CreateComplexOp>(op);
OperandAdaptor<CreateComplexOp> transformed(operands);
CreateComplexOp::Adaptor transformed(operands);
// Pack real and imaginary part in a complex number struct.
auto loc = op->getLoc();
@ -1356,7 +1356,7 @@ struct ReOpLowering : public ConvertOpToLLVMPattern<ReOp> {
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
OperandAdaptor<ReOp> transformed(operands);
ReOp::Adaptor transformed(operands);
// Extract real part from the complex number struct.
ComplexStructBuilder complexStruct(transformed.complex());
@ -1373,7 +1373,7 @@ struct ImOpLowering : public ConvertOpToLLVMPattern<ImOp> {
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
OperandAdaptor<ImOp> transformed(operands);
ImOp::Adaptor transformed(operands);
// Extract imaginary part from the complex number struct.
ComplexStructBuilder complexStruct(transformed.complex());
@ -1394,7 +1394,7 @@ unpackBinaryComplexOperands(OpTy op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) {
auto bop = cast<OpTy>(op);
auto loc = bop.getLoc();
OperandAdaptor<OpTy> transformed(operands);
typename OpTy::Adaptor transformed(operands);
// Extract real and imaginary values from operands.
BinaryComplexOperands unpacked;
@ -1847,7 +1847,7 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
OperandAdaptor<CallOpType> transformed(operands);
typename CallOpType::Adaptor transformed(operands);
auto callOp = cast<CallOpType>(op);
// Pack the result types into a struct.
@ -1919,7 +1919,7 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<DeallocOp> {
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
assert(operands.size() == 1 && "dealloc takes one operand");
OperandAdaptor<DeallocOp> transformed(operands);
DeallocOp::Adaptor transformed(operands);
// Insert the `free` declaration if it is not already present.
auto freeFunc =
@ -1949,7 +1949,7 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
OperandAdaptor<RsqrtOp> transformed(operands);
RsqrtOp::Adaptor transformed(operands);
auto operandType =
transformed.operand().getType().dyn_cast<LLVM::LLVMType>();
@ -2029,7 +2029,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
void rewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto memRefCastOp = cast<MemRefCastOp>(op);
OperandAdaptor<MemRefCastOp> transformed(operands);
MemRefCastOp::Adaptor transformed(operands);
auto srcType = memRefCastOp.getOperand().getType();
auto dstType = memRefCastOp.getType();
@ -2098,7 +2098,7 @@ struct DialectCastOpLowering
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto castOp = cast<LLVM::DialectCastOp>(op);
OperandAdaptor<LLVM::DialectCastOp> transformed(operands);
LLVM::DialectCastOp::Adaptor transformed(operands);
if (transformed.in().getType() !=
typeConverter.convertType(castOp.getType())) {
return failure();
@ -2117,7 +2117,7 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<DimOp> {
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto dimOp = cast<DimOp>(op);
OperandAdaptor<DimOp> transformed(operands);
DimOp::Adaptor transformed(operands);
MemRefType type = dimOp.memrefOrTensor().getType().cast<MemRefType>();
Optional<int64_t> index = dimOp.getConstantIndex();
@ -2163,7 +2163,7 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loadOp = cast<LoadOp>(op);
OperandAdaptor<LoadOp> transformed(operands);
LoadOp::Adaptor transformed(operands);
auto type = loadOp.getMemRefType();
Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
@ -2182,7 +2182,7 @@ struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto type = cast<StoreOp>(op).getMemRefType();
OperandAdaptor<StoreOp> transformed(operands);
StoreOp::Adaptor transformed(operands);
Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
transformed.indices(), rewriter, getModule());
@ -2201,7 +2201,7 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> {
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto prefetchOp = cast<PrefetchOp>(op);
OperandAdaptor<PrefetchOp> transformed(operands);
PrefetchOp::Adaptor transformed(operands);
auto type = prefetchOp.getMemRefType();
Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
@ -2235,7 +2235,7 @@ struct IndexCastOpLowering : public ConvertOpToLLVMPattern<IndexCastOp> {
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
IndexCastOpOperandAdaptor transformed(operands);
IndexCastOpAdaptor transformed(operands);
auto indexCastOp = cast<IndexCastOp>(op);
auto targetType =
@ -2271,7 +2271,7 @@ struct CmpIOpLowering : public ConvertOpToLLVMPattern<CmpIOp> {
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto cmpiOp = cast<CmpIOp>(op);
CmpIOpOperandAdaptor transformed(operands);
CmpIOpAdaptor transformed(operands);
rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
op, typeConverter.convertType(cmpiOp.getResult().getType()),
@ -2290,7 +2290,7 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<CmpFOp> {
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto cmpfOp = cast<CmpFOp>(op);
CmpFOpOperandAdaptor transformed(operands);
CmpFOpAdaptor transformed(operands);
rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
op, typeConverter.convertType(cmpfOp.getResult().getType()),
@ -2449,7 +2449,7 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto splatOp = cast<SplatOp>(op);
OperandAdaptor<SplatOp> adaptor(operands);
SplatOp::Adaptor adaptor(operands);
VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
if (!resultType || resultType.getRank() == 1)
return failure();
@ -2647,7 +2647,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto viewOp = cast<ViewOp>(op);
ViewOpOperandAdaptor adaptor(operands);
ViewOpAdaptor adaptor(operands);
auto viewMemRefType = viewOp.getType();
auto targetElementTy =
@ -2721,7 +2721,7 @@ struct AssumeAlignmentOpLowering
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
OperandAdaptor<AssumeAlignmentOp> transformed(operands);
AssumeAlignmentOp::Adaptor transformed(operands);
Value memref = transformed.memref();
unsigned alignment = cast<AssumeAlignmentOp>(op).alignment().getZExtValue();
@ -2791,7 +2791,7 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
auto maybeKind = matchSimpleAtomicOp(atomicOp);
if (!maybeKind)
return failure();
OperandAdaptor<AtomicRMWOp> adaptor(operands);
AtomicRMWOp::Adaptor adaptor(operands);
auto resultType = adaptor.value().getType();
auto memRefType = atomicOp.getMemRefType();
auto dataPtr = getDataPtr(op->getLoc(), memRefType, adaptor.memref(),
@ -2840,7 +2840,7 @@ struct GenericAtomicRMWOpLowering
auto atomicOp = cast<GenericAtomicRMWOp>(op);
auto loc = op->getLoc();
OperandAdaptor<GenericAtomicRMWOp> adaptor(operands);
GenericAtomicRMWOp::Adaptor adaptor(operands);
LLVM::LLVMType valueType =
typeConverter.convertType(atomicOp.getResult().getType())
.cast<LLVM::LLVMType>();

View File

@ -653,7 +653,7 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite(
LogicalResult
CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
CmpFOpOperandAdaptor cmpFOpOperands(operands);
CmpFOpAdaptor cmpFOpOperands(operands);
switch (cmpFOp.getPredicate()) {
#define DISPATCH(cmpPredicate, spirvOp) \
@ -693,7 +693,7 @@ CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
LogicalResult
BoolCmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
CmpIOpOperandAdaptor cmpIOpOperands(operands);
CmpIOpAdaptor cmpIOpOperands(operands);
Type operandType = cmpIOp.lhs().getType();
if (!operandType.isa<IntegerType>() ||
@ -720,7 +720,7 @@ BoolCmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
LogicalResult
CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
CmpIOpOperandAdaptor cmpIOpOperands(operands);
CmpIOpAdaptor cmpIOpOperands(operands);
Type operandType = cmpIOp.lhs().getType();
if (operandType.isa<IntegerType>() &&
@ -763,7 +763,7 @@ CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
LogicalResult
IntLoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
LoadOpOperandAdaptor loadOperands(operands);
LoadOpAdaptor loadOperands(operands);
auto loc = loadOp.getLoc();
auto memrefType = loadOp.memref().getType().cast<MemRefType>();
if (!memrefType.getElementType().isSignlessInteger())
@ -838,7 +838,7 @@ IntLoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
LogicalResult
LoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
LoadOpOperandAdaptor loadOperands(operands);
LoadOpAdaptor loadOperands(operands);
auto memrefType = loadOp.memref().getType().cast<MemRefType>();
if (memrefType.getElementType().isSignlessInteger())
return failure();
@ -870,7 +870,7 @@ ReturnOpPattern::matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
LogicalResult
SelectOpPattern::matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
SelectOpOperandAdaptor selectOperands(operands);
SelectOpAdaptor selectOperands(operands);
rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, selectOperands.condition(),
selectOperands.true_value(),
selectOperands.false_value());
@ -884,7 +884,7 @@ SelectOpPattern::matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
LogicalResult
IntStoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
StoreOpOperandAdaptor storeOperands(operands);
StoreOpAdaptor storeOperands(operands);
auto memrefType = storeOp.memref().getType().cast<MemRefType>();
if (!memrefType.getElementType().isSignlessInteger())
return failure();
@ -963,7 +963,7 @@ IntStoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
LogicalResult
StoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
StoreOpOperandAdaptor storeOperands(operands);
StoreOpAdaptor storeOperands(operands);
auto memrefType = storeOp.memref().getType().cast<MemRefType>();
if (memrefType.getElementType().isSignlessInteger())
return failure();

View File

@ -176,7 +176,7 @@ replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
LLVMTypeConverter &typeConverter, Location loc,
TransferWriteOp xferOp,
ArrayRef<Value> operands, Value dataPtr) {
auto adaptor = TransferWriteOpOperandAdaptor(operands);
auto adaptor = TransferWriteOpAdaptor(operands);
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr);
return success();
}
@ -190,21 +190,21 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
if (failed(getVectorTransferAlignment(typeConverter, xferOp, align)))
return failure();
auto adaptor = TransferWriteOpOperandAdaptor(operands);
auto adaptor = TransferWriteOpAdaptor(operands);
rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
xferOp, adaptor.vector(), dataPtr, mask,
rewriter.getI32IntegerAttr(align));
return success();
}
static TransferReadOpOperandAdaptor
getTransferOpAdapter(TransferReadOp xferOp, ArrayRef<Value> operands) {
return TransferReadOpOperandAdaptor(operands);
static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
ArrayRef<Value> operands) {
return TransferReadOpAdaptor(operands);
}
static TransferWriteOpOperandAdaptor
getTransferOpAdapter(TransferWriteOp xferOp, ArrayRef<Value> operands) {
return TransferWriteOpOperandAdaptor(operands);
static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
ArrayRef<Value> operands) {
return TransferWriteOpAdaptor(operands);
}
namespace {
@ -222,7 +222,7 @@ public:
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto matmulOp = cast<vector::MatmulOp>(op);
auto adaptor = vector::MatmulOpOperandAdaptor(operands);
auto adaptor = vector::MatmulOpAdaptor(operands);
rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
@ -244,7 +244,7 @@ public:
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto transOp = cast<vector::FlatTransposeOp>(op);
auto adaptor = vector::FlatTransposeOpOperandAdaptor(operands);
auto adaptor = vector::FlatTransposeOpAdaptor(operands);
rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
transOp, typeConverter.convertType(transOp.res().getType()),
adaptor.matrix(), transOp.rows(), transOp.columns());
@ -337,7 +337,7 @@ public:
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto adaptor = vector::ShuffleOpOperandAdaptor(operands);
auto adaptor = vector::ShuffleOpAdaptor(operands);
auto shuffleOp = cast<vector::ShuffleOp>(op);
auto v1Type = shuffleOp.getV1VectorType();
auto v2Type = shuffleOp.getV2VectorType();
@ -394,7 +394,7 @@ public:
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::ExtractElementOpOperandAdaptor(operands);
auto adaptor = vector::ExtractElementOpAdaptor(operands);
auto extractEltOp = cast<vector::ExtractElementOp>(op);
auto vectorType = extractEltOp.getVectorType();
auto llvmType = typeConverter.convertType(vectorType.getElementType());
@ -420,7 +420,7 @@ public:
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto adaptor = vector::ExtractOpOperandAdaptor(operands);
auto adaptor = vector::ExtractOpAdaptor(operands);
auto extractOp = cast<vector::ExtractOp>(op);
auto vectorType = extractOp.getVectorType();
auto resultType = extractOp.getResult().getType();
@ -488,7 +488,7 @@ public:
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::FMAOpOperandAdaptor(operands);
auto adaptor = vector::FMAOpAdaptor(operands);
vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
VectorType vType = fmaOp.getVectorType();
if (vType.getRank() != 1)
@ -509,7 +509,7 @@ public:
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::InsertElementOpOperandAdaptor(operands);
auto adaptor = vector::InsertElementOpAdaptor(operands);
auto insertEltOp = cast<vector::InsertElementOp>(op);
auto vectorType = insertEltOp.getDestVectorType();
auto llvmType = typeConverter.convertType(vectorType);
@ -535,7 +535,7 @@ public:
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto adaptor = vector::InsertOpOperandAdaptor(operands);
auto adaptor = vector::InsertOpAdaptor(operands);
auto insertOp = cast<vector::InsertOp>(op);
auto sourceType = insertOp.getSourceType();
auto destVectorType = insertOp.getDestVectorType();
@ -967,7 +967,7 @@ public:
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto printOp = cast<vector::PrintOp>(op);
auto adaptor = vector::PrintOpOperandAdaptor(operands);
auto adaptor = vector::PrintOpAdaptor(operands);
Type printType = printOp.getPrintType();
if (typeConverter.convertType(printType) == nullptr)

View File

@ -27,16 +27,6 @@
using namespace mlir;
using namespace mlir::vector;
static TransferReadOpOperandAdaptor
getTransferOpAdapter(TransferReadOp xferOp, ArrayRef<Value> operands) {
return OperandAdaptor<TransferReadOp>(operands);
}
static TransferWriteOpOperandAdaptor
getTransferOpAdapter(TransferWriteOp xferOp, ArrayRef<Value> operands) {
return OperandAdaptor<TransferWriteOp>(operands);
}
static LogicalResult replaceTransferOpWithMubuf(
ConversionPatternRewriter &rewriter, ArrayRef<Value> operands,
LLVMTypeConverter &typeConverter, Location loc, TransferReadOp xferOp,
@ -52,7 +42,7 @@ static LogicalResult replaceTransferOpWithMubuf(
LLVMTypeConverter &typeConverter, Location loc, TransferWriteOp xferOp,
LLVM::LLVMType &vecTy, Value &dwordConfig, Value &vindex,
Value &offsetSizeInBytes, Value &glc, Value &slc) {
auto adaptor = TransferWriteOpOperandAdaptor(operands);
auto adaptor = TransferWriteOpAdaptor(operands);
rewriter.replaceOpWithNewOp<ROCDL::MubufStoreOp>(xferOp, adaptor.vector(),
dwordConfig, vindex,
offsetSizeInBytes, glc, slc);
@ -76,7 +66,7 @@ public:
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto xferOp = cast<ConcreteOp>(op);
auto adaptor = getTransferOpAdapter(xferOp, operands);
typename ConcreteOp::Adaptor adaptor(operands);
if (xferOp.getVectorType().getRank() > 1 ||
llvm::size(xferOp.indices()) == 0)

View File

@ -779,7 +779,7 @@ static void print(OpAsmPrinter &p, GPUModuleOp op) {
/*printBlockTerminators=*/false);
}
// Namespace avoids ambiguous ReturnOpOperandAdaptor.
// Namespace avoids ambiguous ReturnOpAdaptor.
namespace mlir {
namespace gpu {
#define GET_OP_CLASSES

View File

@ -227,8 +227,7 @@ void tblgen::OpClass::writeDeclTo(raw_ostream &os) const {
os << ", " << trait;
os << "> {\npublic:\n";
os << " using Op::Op;\n";
os << " using OperandAdaptor = " << className << "OperandAdaptor;\n";
os << " using Adaptor = " << className << "OperandAdaptor;\n";
os << " using Adaptor = " << className << "Adaptor;\n";
bool hasPrivateMethod = false;
for (const auto &method : methods) {

View File

@ -60,7 +60,7 @@ std::string tblgen::Operator::getOperationName() const {
}
std::string tblgen::Operator::getAdaptorName() const {
return std::string(llvm::formatv("{0}OperandAdaptor", getCppClassName()));
return std::string(llvm::formatv("{0}Adaptor", getCppClassName()));
}
StringRef tblgen::Operator::getDialectName() const { return dialect.getName(); }

View File

@ -33,7 +33,7 @@ def AOp : NS_Op<"a_op", []> {
// Test verify method
// ---
// DEF: LogicalResult AOpOperandAdaptor::verify
// DEF: LogicalResult AOpAdaptor::verify
// DEF: auto tblgen_aAttr = odsAttrs.get("aAttr");
// DEF-NEXT: if (!tblgen_aAttr) return emitError(loc, "'test.a_op' op ""requires attribute 'aAttr'");
// DEF: if (!((some-condition))) return emitError(loc, "'test.a_op' op ""attribute 'aAttr' failed to satisfy constraint: some attribute kind");
@ -118,7 +118,7 @@ def BOp : NS_Op<"b_op", []> {
// Test common attribute kinds' constraints
// ---
// DEF-LABEL: BOpOperandAdaptor::verify
// DEF-LABEL: BOpAdaptor::verify
// DEF: if (!((true)))
// DEF: if (!((tblgen_bool_attr.isa<BoolAttr>())))
// DEF: if (!(((tblgen_i32_attr.isa<IntegerAttr>())) && ((tblgen_i32_attr.cast<IntegerAttr>().getType().isSignlessInteger(32)))))

View File

@ -47,9 +47,9 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
// CHECK-LABEL: NS::AOp declarations
// CHECK: class AOpOperandAdaptor {
// CHECK: class AOpAdaptor {
// CHECK: public:
// CHECK: AOpOperandAdaptor(ValueRange values
// CHECK: AOpAdaptor(ValueRange values
// CHECK: ValueRange getODSOperands(unsigned index);
// CHECK: Value a();
// CHECK: ValueRange b();
@ -63,7 +63,7 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
// CHECK-NOT: OpTrait::IsIsolatedFromAbove
// CHECK: public:
// CHECK: using Op::Op;
// CHECK: using OperandAdaptor = AOpOperandAdaptor;
// CHECK: using Adaptor = AOpAdaptor;
// CHECK: static StringRef getOperationName();
// CHECK: Operation::operand_range getODSOperands(unsigned index);
// CHECK: Value a();
@ -105,7 +105,7 @@ def NS_AttrSizedOperandOp : NS_Op<"attr_sized_operands",
);
}
// CHECK-LABEL: AttrSizedOperandOpOperandAdaptor(
// CHECK-LABEL: AttrSizedOperandOpAdaptor(
// CHECK-SAME: ValueRange values
// CHECK-SAME: DictionaryAttr attrs
// CHECK: ValueRange a();

View File

@ -14,7 +14,7 @@ def OpA : NS_Op<"one_normal_operand_op", []> {
// CHECK-LABEL: OpA definitions
// CHECK: OpAOperandAdaptor::OpAOperandAdaptor
// CHECK: OpAAdaptor::OpAAdaptor
// CHECK-SAME: odsOperands(values), odsAttrs(attrs)
// CHECK: void OpA::build
@ -39,13 +39,13 @@ def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]>
let arguments = (ins Variadic<AnyTensor>:$input1, AnyTensor:$input2, Variadic<AnyTensor>:$input3);
}
// CHECK-LABEL: ValueRange OpDOperandAdaptor::input1
// CHECK-LABEL: ValueRange OpDAdaptor::input1
// CHECK-NEXT: return getODSOperands(0);
// CHECK-LABEL: Value OpDOperandAdaptor::input2
// CHECK-LABEL: Value OpDAdaptor::input2
// CHECK-NEXT: return *getODSOperands(1).begin();
// CHECK-LABEL: ValueRange OpDOperandAdaptor::input3
// CHECK-LABEL: ValueRange OpDAdaptor::input3
// CHECK-NEXT: return getODSOperands(2);
// CHECK-LABEL: Operation::operand_range OpD::input1

View File

@ -32,7 +32,7 @@ def OpF : NS_Op<"op_for_int_min_val", []> {
let arguments = (ins Confined<I32Attr, [IntMinValue<10>]>:$attr);
}
// CHECK-LABEL: OpFOperandAdaptor::verify
// CHECK-LABEL: OpFAdaptor::verify
// CHECK: (tblgen_attr.cast<IntegerAttr>().getInt() >= 10)
// CHECK-SAME: "attribute 'attr' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 10"
@ -40,7 +40,7 @@ def OpFX : NS_Op<"op_for_int_max_val", []> {
let arguments = (ins Confined<I32Attr, [IntMaxValue<10>]>:$attr);
}
// CHECK-LABEL: OpFXOperandAdaptor::verify
// CHECK-LABEL: OpFXAdaptor::verify
// CHECK: (tblgen_attr.cast<IntegerAttr>().getInt() <= 10)
// CHECK-SAME: "attribute 'attr' failed to satisfy constraint: 32-bit signless integer attribute whose maximum value is 10"
@ -48,7 +48,7 @@ def OpG : NS_Op<"op_for_arr_min_count", []> {
let arguments = (ins Confined<ArrayAttr, [ArrayMinCount<8>]>:$attr);
}
// CHECK-LABEL: OpGOperandAdaptor::verify
// CHECK-LABEL: OpGAdaptor::verify
// CHECK: (tblgen_attr.cast<ArrayAttr>().size() >= 8)
// CHECK-SAME: "attribute 'attr' failed to satisfy constraint: array attribute with at least 8 elements"
@ -56,7 +56,7 @@ def OpH : NS_Op<"op_for_arr_value_at_index", []> {
let arguments = (ins Confined<ArrayAttr, [IntArrayNthElemEq<0, 8>]>:$attr);
}
// CHECK-LABEL: OpHOperandAdaptor::verify
// CHECK-LABEL: OpHAdaptor::verify
// CHECK: (((tblgen_attr.cast<ArrayAttr>().size() > 0)) && ((tblgen_attr.cast<ArrayAttr>()[0].cast<IntegerAttr>().getInt() == 8)))))
// CHECK-SAME: "attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be 8"
@ -64,7 +64,7 @@ def OpI: NS_Op<"op_for_arr_min_value_at_index", []> {
let arguments = (ins Confined<ArrayAttr, [IntArrayNthElemMinValue<0, 8>]>:$attr);
}
// CHECK-LABEL: OpIOperandAdaptor::verify
// CHECK-LABEL: OpIAdaptor::verify
// CHECK: (((tblgen_attr.cast<ArrayAttr>().size() > 0)) && ((tblgen_attr.cast<ArrayAttr>()[0].cast<IntegerAttr>().getInt() >= 8)))))
// CHECK-SAME: "attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be at least 8"
@ -80,7 +80,7 @@ def OpJ: NS_Op<"op_for_TCopVTEtAreSameAt", [
);
}
// CHECK-LABEL: OpJOperandAdaptor::verify
// CHECK-LABEL: OpJAdaptor::verify
// CHECK: llvm::is_splat(llvm::map_range(
// CHECK-SAME: llvm::ArrayRef<unsigned>({0, 2, 3}),
// CHECK-SAME: [this](unsigned i) { return getElementTypeOrSelf(this->getOperand(i)); }))