forked from OSchip/llvm-project
[mlir] Flip accessors to prefixed form (NFC)
Another mechanical sweep to keep diff small for flip to _Prefixed.
This commit is contained in:
parent
f2e1d2cec0
commit
136d746ec7
|
@ -43,7 +43,7 @@ struct AllocOpLowering : public AllocLikeOpLLVMLowering {
|
|||
MemRefType memRefType = allocOp.getType();
|
||||
|
||||
Value alignment;
|
||||
if (auto alignmentAttr = allocOp.alignment()) {
|
||||
if (auto alignmentAttr = allocOp.getAlignment()) {
|
||||
alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
|
||||
} else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
|
||||
// In the case where no alignment is specified, we may want to override
|
||||
|
@ -124,7 +124,7 @@ struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
|
|||
/// aligned_alloc requires the allocation size to be a power of two, and the
|
||||
/// allocation size to be a multiple of alignment,
|
||||
int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
|
||||
if (Optional<uint64_t> alignment = allocOp.alignment())
|
||||
if (Optional<uint64_t> alignment = allocOp.getAlignment())
|
||||
return *alignment;
|
||||
|
||||
// Whenever we don't have alignment set, we will use an alignment
|
||||
|
@ -190,7 +190,7 @@ struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
|
|||
|
||||
auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
|
||||
loc, elementPtrType, sizeBytes,
|
||||
allocaOp.alignment() ? *allocaOp.alignment() : 0);
|
||||
allocaOp.getAlignment() ? *allocaOp.getAlignment() : 0);
|
||||
|
||||
return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
|
||||
}
|
||||
|
@ -223,9 +223,9 @@ struct AllocaScopeOpLowering
|
|||
}
|
||||
|
||||
// Inline body region.
|
||||
Block *beforeBody = &allocaScopeOp.bodyRegion().front();
|
||||
Block *afterBody = &allocaScopeOp.bodyRegion().back();
|
||||
rewriter.inlineRegionBefore(allocaScopeOp.bodyRegion(), continueBlock);
|
||||
Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
|
||||
Block *afterBody = &allocaScopeOp.getBodyRegion().back();
|
||||
rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);
|
||||
|
||||
// Save stack and then branch into the body of the region.
|
||||
rewriter.setInsertionPointToEnd(currentBlock);
|
||||
|
@ -239,7 +239,7 @@ struct AllocaScopeOpLowering
|
|||
auto returnOp =
|
||||
cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
|
||||
auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
|
||||
returnOp, returnOp.results(), continueBlock);
|
||||
returnOp, returnOp.getResults(), continueBlock);
|
||||
|
||||
// Insert stack restore before jumping out the body of the region.
|
||||
rewriter.setInsertionPoint(branchOp);
|
||||
|
@ -260,8 +260,8 @@ struct AssumeAlignmentOpLowering
|
|||
LogicalResult
|
||||
matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value memref = adaptor.memref();
|
||||
unsigned alignment = op.alignment();
|
||||
Value memref = adaptor.getMemref();
|
||||
unsigned alignment = op.getAlignment();
|
||||
auto loc = op.getLoc();
|
||||
|
||||
MemRefDescriptor memRefDescriptor(memref);
|
||||
|
@ -305,7 +305,7 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
|
|||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Insert the `free` declaration if it is not already present.
|
||||
auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
|
||||
MemRefDescriptor memref(adaptor.memref());
|
||||
MemRefDescriptor memref(adaptor.getMemref());
|
||||
Value casted = rewriter.create<LLVM::BitcastOp>(
|
||||
op.getLoc(), getVoidPtrType(),
|
||||
memref.allocatedPtr(rewriter, op.getLoc()));
|
||||
|
@ -323,7 +323,7 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
|
|||
LogicalResult
|
||||
matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type operandType = dimOp.source().getType();
|
||||
Type operandType = dimOp.getSource().getType();
|
||||
if (operandType.isa<UnrankedMemRefType>()) {
|
||||
rewriter.replaceOp(
|
||||
dimOp, {extractSizeOfUnrankedMemRef(
|
||||
|
@ -354,7 +354,7 @@ private:
|
|||
// Extract pointer to the underlying ranked descriptor and bitcast it to a
|
||||
// memref<element_type> descriptor pointer to minimize the number of GEP
|
||||
// operations.
|
||||
UnrankedMemRefDescriptor unrankedDesc(adaptor.source());
|
||||
UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
|
||||
Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
|
||||
Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
|
||||
loc,
|
||||
|
@ -375,7 +375,7 @@ private:
|
|||
// The size value that we have to extract can be obtained using GEPop with
|
||||
// `dimOp.index() + 1` index argument.
|
||||
Value idxPlusOne = rewriter.create<LLVM::AddOp>(
|
||||
loc, createIndexConstant(rewriter, loc, 1), adaptor.index());
|
||||
loc, createIndexConstant(rewriter, loc, 1), adaptor.getIndex());
|
||||
Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
|
||||
ValueRange({idxPlusOne}));
|
||||
return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
|
||||
|
@ -385,7 +385,7 @@ private:
|
|||
if (Optional<int64_t> idx = dimOp.getConstantIndex())
|
||||
return idx;
|
||||
|
||||
if (auto constantOp = dimOp.index().getDefiningOp<LLVM::ConstantOp>())
|
||||
if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
|
||||
return constantOp.getValue()
|
||||
.cast<IntegerAttr>()
|
||||
.getValue()
|
||||
|
@ -405,16 +405,16 @@ private:
|
|||
int64_t i = *index;
|
||||
if (memRefType.isDynamicDim(i)) {
|
||||
// extract dynamic size from the memref descriptor.
|
||||
MemRefDescriptor descriptor(adaptor.source());
|
||||
MemRefDescriptor descriptor(adaptor.getSource());
|
||||
return descriptor.size(rewriter, loc, i);
|
||||
}
|
||||
// Use constant for static size.
|
||||
int64_t dimSize = memRefType.getDimSize(i);
|
||||
return createIndexConstant(rewriter, loc, dimSize);
|
||||
}
|
||||
Value index = adaptor.index();
|
||||
Value index = adaptor.getIndex();
|
||||
int64_t rank = memRefType.getRank();
|
||||
MemRefDescriptor memrefDescriptor(adaptor.source());
|
||||
MemRefDescriptor memrefDescriptor(adaptor.getSource());
|
||||
return memrefDescriptor.size(rewriter, loc, index, rank);
|
||||
}
|
||||
};
|
||||
|
@ -485,9 +485,9 @@ struct GenericAtomicRMWOpLowering
|
|||
|
||||
// Compute the loaded value and branch to the loop block.
|
||||
rewriter.setInsertionPointToEnd(initBlock);
|
||||
auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
|
||||
auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
|
||||
adaptor.indices(), rewriter);
|
||||
auto memRefType = atomicOp.getMemref().getType().cast<MemRefType>();
|
||||
auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
|
||||
adaptor.getIndices(), rewriter);
|
||||
Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
|
||||
rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
|
||||
|
||||
|
@ -576,7 +576,7 @@ struct GlobalMemrefOpLowering
|
|||
LogicalResult
|
||||
matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
MemRefType type = global.type();
|
||||
MemRefType type = global.getType();
|
||||
if (!isConvertibleAndHasIdentityMaps(type))
|
||||
return failure();
|
||||
|
||||
|
@ -587,7 +587,7 @@ struct GlobalMemrefOpLowering
|
|||
|
||||
Attribute initialValue = nullptr;
|
||||
if (!global.isExternal() && !global.isUninitialized()) {
|
||||
auto elementsAttr = global.initial_value()->cast<ElementsAttr>();
|
||||
auto elementsAttr = global.getInitialValue()->cast<ElementsAttr>();
|
||||
initialValue = elementsAttr;
|
||||
|
||||
// For scalar memrefs, the global variable created is of the element type,
|
||||
|
@ -596,10 +596,10 @@ struct GlobalMemrefOpLowering
|
|||
initialValue = elementsAttr.getSplatValue<Attribute>();
|
||||
}
|
||||
|
||||
uint64_t alignment = global.alignment().value_or(0);
|
||||
uint64_t alignment = global.getAlignment().value_or(0);
|
||||
|
||||
auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
|
||||
global, arrayTy, global.constant(), linkage, global.sym_name(),
|
||||
global, arrayTy, global.getConstant(), linkage, global.getSymName(),
|
||||
initialValue, alignment, type.getMemorySpaceAsInt());
|
||||
if (!global.isExternal() && global.isUninitialized()) {
|
||||
Block *blk = new Block();
|
||||
|
@ -627,12 +627,13 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
|
|||
Location loc, Value sizeBytes,
|
||||
Operation *op) const override {
|
||||
auto getGlobalOp = cast<memref::GetGlobalOp>(op);
|
||||
MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
|
||||
MemRefType type = getGlobalOp.getResult().getType().cast<MemRefType>();
|
||||
unsigned memSpace = type.getMemorySpaceAsInt();
|
||||
|
||||
Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
|
||||
auto addressOf = rewriter.create<LLVM::AddressOfOp>(
|
||||
loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name());
|
||||
loc, LLVM::LLVMPointerType::get(arrayTy, memSpace),
|
||||
getGlobalOp.getName());
|
||||
|
||||
// Get the address of the first element in the array by creating a GEP with
|
||||
// the address of the GV as the base, and (rank + 1) number of 0 indices.
|
||||
|
@ -670,8 +671,9 @@ struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
|
|||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto type = loadOp.getMemRefType();
|
||||
|
||||
Value dataPtr = getStridedElementPtr(
|
||||
loadOp.getLoc(), type, adaptor.memref(), adaptor.indices(), rewriter);
|
||||
Value dataPtr =
|
||||
getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
|
||||
adaptor.getIndices(), rewriter);
|
||||
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
|
||||
return success();
|
||||
}
|
||||
|
@ -687,9 +689,9 @@ struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
|
|||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto type = op.getMemRefType();
|
||||
|
||||
Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.memref(),
|
||||
adaptor.indices(), rewriter);
|
||||
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.value(), dataPtr);
|
||||
Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
|
||||
adaptor.getIndices(), rewriter);
|
||||
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -705,18 +707,19 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
|
|||
auto type = prefetchOp.getMemRefType();
|
||||
auto loc = prefetchOp.getLoc();
|
||||
|
||||
Value dataPtr = getStridedElementPtr(loc, type, adaptor.memref(),
|
||||
adaptor.indices(), rewriter);
|
||||
Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
|
||||
adaptor.getIndices(), rewriter);
|
||||
|
||||
// Replace with llvm.prefetch.
|
||||
auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
|
||||
auto isWrite = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
|
||||
loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.getIsWrite()));
|
||||
auto localityHint = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, llvmI32Type,
|
||||
rewriter.getI32IntegerAttr(prefetchOp.localityHint()));
|
||||
rewriter.getI32IntegerAttr(prefetchOp.getLocalityHint()));
|
||||
auto isData = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));
|
||||
loc, llvmI32Type,
|
||||
rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache()));
|
||||
|
||||
rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
|
||||
localityHint, isData);
|
||||
|
@ -731,9 +734,9 @@ struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
|
|||
matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
Type operandType = op.memref().getType();
|
||||
Type operandType = op.getMemref().getType();
|
||||
if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
|
||||
UnrankedMemRefDescriptor desc(adaptor.memref());
|
||||
UnrankedMemRefDescriptor desc(adaptor.getMemref());
|
||||
rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
|
||||
return success();
|
||||
}
|
||||
|
@ -782,7 +785,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
|
|||
|
||||
// For ranked/ranked case, just keep the original descriptor.
|
||||
if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
|
||||
return rewriter.replaceOp(memRefCastOp, {adaptor.source()});
|
||||
return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
|
||||
|
||||
if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
|
||||
// Casting ranked to unranked memref type
|
||||
|
@ -793,7 +796,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
|
|||
int64_t rank = srcMemRefType.getRank();
|
||||
// ptr = AllocaOp sizeof(MemRefDescriptor)
|
||||
auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
|
||||
loc, adaptor.source(), rewriter);
|
||||
loc, adaptor.getSource(), rewriter);
|
||||
// voidptr = BitCastOp srcType* to void*
|
||||
auto voidPtr =
|
||||
rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
|
||||
|
@ -814,7 +817,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
|
|||
// Casting from unranked type to ranked.
|
||||
// The operation is assumed to be doing a correct cast. If the destination
|
||||
// type mismatches the unranked the type, it is undefined behavior.
|
||||
UnrankedMemRefDescriptor memRefDesc(adaptor.source());
|
||||
UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
|
||||
// ptr = ExtractValueOp src, 1
|
||||
auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
|
||||
// castPtr = BitCastOp i8* to structTy*
|
||||
|
@ -844,9 +847,9 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
|
|||
lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto loc = op.getLoc();
|
||||
auto srcType = op.source().getType().dyn_cast<MemRefType>();
|
||||
auto srcType = op.getSource().getType().dyn_cast<MemRefType>();
|
||||
|
||||
MemRefDescriptor srcDesc(adaptor.source());
|
||||
MemRefDescriptor srcDesc(adaptor.getSource());
|
||||
|
||||
// Compute number of elements.
|
||||
Value numElements = rewriter.create<LLVM::ConstantOp>(
|
||||
|
@ -866,7 +869,7 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
|
|||
Value srcOffset = srcDesc.offset(rewriter, loc);
|
||||
Value srcPtr = rewriter.create<LLVM::GEPOp>(loc, srcBasePtr.getType(),
|
||||
srcBasePtr, srcOffset);
|
||||
MemRefDescriptor targetDesc(adaptor.target());
|
||||
MemRefDescriptor targetDesc(adaptor.getTarget());
|
||||
Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
|
||||
Value targetOffset = targetDesc.offset(rewriter, loc);
|
||||
Value targetPtr = rewriter.create<LLVM::GEPOp>(loc, targetBasePtr.getType(),
|
||||
|
@ -885,8 +888,8 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
|
|||
lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto loc = op.getLoc();
|
||||
auto srcType = op.source().getType().cast<BaseMemRefType>();
|
||||
auto targetType = op.target().getType().cast<BaseMemRefType>();
|
||||
auto srcType = op.getSource().getType().cast<BaseMemRefType>();
|
||||
auto targetType = op.getTarget().getType().cast<BaseMemRefType>();
|
||||
|
||||
// First make sure we have an unranked memref descriptor representation.
|
||||
auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
|
||||
|
@ -906,11 +909,11 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
|
|||
};
|
||||
|
||||
Value unrankedSource = srcType.hasRank()
|
||||
? makeUnranked(adaptor.source(), srcType)
|
||||
: adaptor.source();
|
||||
? makeUnranked(adaptor.getSource(), srcType)
|
||||
: adaptor.getSource();
|
||||
Value unrankedTarget = targetType.hasRank()
|
||||
? makeUnranked(adaptor.target(), targetType)
|
||||
: adaptor.target();
|
||||
? makeUnranked(adaptor.getTarget(), targetType)
|
||||
: adaptor.getTarget();
|
||||
|
||||
// Now promote the unranked descriptors to the stack.
|
||||
auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
|
||||
|
@ -942,8 +945,8 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
|
|||
LogicalResult
|
||||
matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto srcType = op.source().getType().cast<BaseMemRefType>();
|
||||
auto targetType = op.target().getType().cast<BaseMemRefType>();
|
||||
auto srcType = op.getSource().getType().cast<BaseMemRefType>();
|
||||
auto targetType = op.getTarget().getType().cast<BaseMemRefType>();
|
||||
|
||||
auto isContiguousMemrefType = [](BaseMemRefType type) {
|
||||
auto memrefType = type.dyn_cast<mlir::MemRefType>();
|
||||
|
@ -1013,7 +1016,7 @@ struct MemRefReinterpretCastOpLowering
|
|||
LogicalResult
|
||||
matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type srcType = castOp.source().getType();
|
||||
Type srcType = castOp.getSource().getType();
|
||||
|
||||
Value descriptor;
|
||||
if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
|
||||
|
@ -1042,14 +1045,14 @@ private:
|
|||
// Set allocated and aligned pointers.
|
||||
Value allocatedPtr, alignedPtr;
|
||||
extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
|
||||
castOp.source(), adaptor.source(), &allocatedPtr,
|
||||
&alignedPtr);
|
||||
castOp.getSource(), adaptor.getSource(),
|
||||
&allocatedPtr, &alignedPtr);
|
||||
desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
|
||||
desc.setAlignedPtr(rewriter, loc, alignedPtr);
|
||||
|
||||
// Set offset.
|
||||
if (castOp.isDynamicOffset(0))
|
||||
desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
|
||||
desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
|
||||
else
|
||||
desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
|
||||
|
||||
|
@ -1058,12 +1061,12 @@ private:
|
|||
unsigned dynStrideId = 0;
|
||||
for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
|
||||
if (castOp.isDynamicSize(i))
|
||||
desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
|
||||
desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
|
||||
else
|
||||
desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
|
||||
|
||||
if (castOp.isDynamicStride(i))
|
||||
desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
|
||||
desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
|
||||
else
|
||||
desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
|
||||
}
|
||||
|
@ -1079,7 +1082,7 @@ struct MemRefReshapeOpLowering
|
|||
LogicalResult
|
||||
matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type srcType = reshapeOp.source().getType();
|
||||
Type srcType = reshapeOp.getSource().getType();
|
||||
|
||||
Value descriptor;
|
||||
if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
|
||||
|
@ -1095,7 +1098,7 @@ private:
|
|||
Type srcType, memref::ReshapeOp reshapeOp,
|
||||
memref::ReshapeOp::Adaptor adaptor,
|
||||
Value *descriptor) const {
|
||||
auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
|
||||
auto shapeMemRefType = reshapeOp.getShape().getType().cast<MemRefType>();
|
||||
if (shapeMemRefType.hasStaticShape()) {
|
||||
MemRefType targetMemRefType =
|
||||
reshapeOp.getResult().getType().cast<MemRefType>();
|
||||
|
@ -1113,7 +1116,7 @@ private:
|
|||
// Set allocated and aligned pointers.
|
||||
Value allocatedPtr, alignedPtr;
|
||||
extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
|
||||
reshapeOp.source(), adaptor.source(),
|
||||
reshapeOp.getSource(), adaptor.getSource(),
|
||||
&allocatedPtr, &alignedPtr);
|
||||
desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
|
||||
desc.setAlignedPtr(rewriter, loc, alignedPtr);
|
||||
|
@ -1155,7 +1158,7 @@ private:
|
|||
if (!ShapedType::isDynamic(size)) {
|
||||
dimSize = createIndexConstant(rewriter, loc, size);
|
||||
} else {
|
||||
Value shapeOp = reshapeOp.shape();
|
||||
Value shapeOp = reshapeOp.getShape();
|
||||
Value index = createIndexConstant(rewriter, loc, i);
|
||||
dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
|
||||
}
|
||||
|
@ -1173,7 +1176,7 @@ private:
|
|||
|
||||
// The shape is a rank-1 tensor with unknown length.
|
||||
Location loc = reshapeOp.getLoc();
|
||||
MemRefDescriptor shapeDesc(adaptor.shape());
|
||||
MemRefDescriptor shapeDesc(adaptor.getShape());
|
||||
Value resultRank = shapeDesc.size(rewriter, loc, 0);
|
||||
|
||||
// Extract address space and element type.
|
||||
|
@ -1197,7 +1200,7 @@ private:
|
|||
// Extract pointers and offset from the source memref.
|
||||
Value allocatedPtr, alignedPtr, offset;
|
||||
extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
|
||||
reshapeOp.source(), adaptor.source(),
|
||||
reshapeOp.getSource(), adaptor.getSource(),
|
||||
&allocatedPtr, &alignedPtr, &offset);
|
||||
|
||||
// Set pointers and offset.
|
||||
|
@ -1555,7 +1558,7 @@ public:
|
|||
reshapeOp, "failed to get stride and offset exprs");
|
||||
}
|
||||
|
||||
MemRefDescriptor srcDesc(adaptor.src());
|
||||
MemRefDescriptor srcDesc(adaptor.getSrc());
|
||||
Location loc = reshapeOp->getLoc();
|
||||
auto dstDesc = MemRefDescriptor::undef(
|
||||
rewriter, loc, this->typeConverter->convertType(dstType));
|
||||
|
@ -1611,17 +1614,18 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
|
|||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = subViewOp.getLoc();
|
||||
|
||||
auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
|
||||
auto sourceMemRefType = subViewOp.getSource().getType().cast<MemRefType>();
|
||||
auto sourceElementTy =
|
||||
typeConverter->convertType(sourceMemRefType.getElementType());
|
||||
|
||||
auto viewMemRefType = subViewOp.getType();
|
||||
auto inferredType = memref::SubViewOp::inferResultType(
|
||||
subViewOp.getSourceType(),
|
||||
extractFromI64ArrayAttr(subViewOp.static_offsets()),
|
||||
extractFromI64ArrayAttr(subViewOp.static_sizes()),
|
||||
extractFromI64ArrayAttr(subViewOp.static_strides()))
|
||||
.cast<MemRefType>();
|
||||
auto inferredType =
|
||||
memref::SubViewOp::inferResultType(
|
||||
subViewOp.getSourceType(),
|
||||
extractFromI64ArrayAttr(subViewOp.getStaticOffsets()),
|
||||
extractFromI64ArrayAttr(subViewOp.getStaticSizes()),
|
||||
extractFromI64ArrayAttr(subViewOp.getStaticStrides()))
|
||||
.cast<MemRefType>();
|
||||
auto targetElementTy =
|
||||
typeConverter->convertType(viewMemRefType.getElementType());
|
||||
auto targetDescTy = typeConverter->convertType(viewMemRefType);
|
||||
|
@ -1717,7 +1721,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
|
|||
// aware of LLVM constants and for this pass to be aware of std
|
||||
// constants.
|
||||
int64_t staticSize =
|
||||
subViewOp.source().getType().cast<MemRefType>().getShape()[i];
|
||||
subViewOp.getSource().getType().cast<MemRefType>().getShape()[i];
|
||||
if (staticSize != ShapedType::kDynamicSize) {
|
||||
size = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
|
||||
|
@ -1725,7 +1729,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
|
|||
Value pos = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
|
||||
Value dim =
|
||||
rewriter.create<memref::DimOp>(loc, subViewOp.source(), pos);
|
||||
rewriter.create<memref::DimOp>(loc, subViewOp.getSource(), pos);
|
||||
auto cast = rewriter.create<UnrealizedConversionCastOp>(
|
||||
loc, llvmIndexType, dim);
|
||||
size = cast.getResult(0);
|
||||
|
@ -1779,10 +1783,10 @@ public:
|
|||
matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = transposeOp.getLoc();
|
||||
MemRefDescriptor viewMemRef(adaptor.in());
|
||||
MemRefDescriptor viewMemRef(adaptor.getIn());
|
||||
|
||||
// No permutation, early exit.
|
||||
if (transposeOp.permutation().isIdentity())
|
||||
if (transposeOp.getPermutation().isIdentity())
|
||||
return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
|
||||
|
||||
auto targetMemRef = MemRefDescriptor::undef(
|
||||
|
@ -1800,7 +1804,7 @@ public:
|
|||
|
||||
// Iterate over the dimensions and apply size/stride permutation.
|
||||
for (const auto &en :
|
||||
llvm::enumerate(transposeOp.permutation().getResults())) {
|
||||
llvm::enumerate(transposeOp.getPermutation().getResults())) {
|
||||
int sourcePos = en.index();
|
||||
int targetPos = en.value().cast<AffineDimExpr>().getPosition();
|
||||
targetMemRef.setSize(rewriter, loc, targetPos,
|
||||
|
@ -1884,12 +1888,12 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
|
|||
failure();
|
||||
|
||||
// Create the descriptor.
|
||||
MemRefDescriptor sourceMemRef(adaptor.source());
|
||||
MemRefDescriptor sourceMemRef(adaptor.getSource());
|
||||
auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
|
||||
|
||||
// Field 1: Copy the allocated pointer, used for malloc/free.
|
||||
Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
|
||||
auto srcMemRefType = viewOp.source().getType().cast<MemRefType>();
|
||||
auto srcMemRefType = viewOp.getSource().getType().cast<MemRefType>();
|
||||
Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
|
||||
loc,
|
||||
LLVM::LLVMPointerType::get(targetElementTy,
|
||||
|
@ -1899,8 +1903,8 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
|
|||
|
||||
// Field 2: Copy the actual aligned pointer to payload.
|
||||
Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
|
||||
alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(),
|
||||
alignedPtr, adaptor.byte_shift());
|
||||
alignedPtr = rewriter.create<LLVM::GEPOp>(
|
||||
loc, alignedPtr.getType(), alignedPtr, adaptor.getByteShift());
|
||||
bitcastPtr = rewriter.create<LLVM::BitcastOp>(
|
||||
loc,
|
||||
LLVM::LLVMPointerType::get(targetElementTy,
|
||||
|
@ -1922,8 +1926,8 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
|
|||
Value stride = nullptr, nextSize = nullptr;
|
||||
for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
|
||||
// Update size.
|
||||
Value size =
|
||||
getSize(rewriter, loc, viewMemRefType.getShape(), adaptor.sizes(), i);
|
||||
Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
|
||||
adaptor.getSizes(), i);
|
||||
targetMemRef.setSize(rewriter, loc, i, size);
|
||||
// Update stride.
|
||||
stride = getStride(rewriter, loc, strides, nextSize, stride, i);
|
||||
|
@ -1944,7 +1948,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
|
|||
/// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
|
||||
static Optional<LLVM::AtomicBinOp>
|
||||
matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
|
||||
switch (atomicOp.kind()) {
|
||||
switch (atomicOp.getKind()) {
|
||||
case arith::AtomicRMWKind::addf:
|
||||
return LLVM::AtomicBinOp::fadd;
|
||||
case arith::AtomicRMWKind::addi:
|
||||
|
@ -1980,13 +1984,13 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
|
|||
auto maybeKind = matchSimpleAtomicOp(atomicOp);
|
||||
if (!maybeKind)
|
||||
return failure();
|
||||
auto resultType = adaptor.value().getType();
|
||||
auto resultType = adaptor.getValue().getType();
|
||||
auto memRefType = atomicOp.getMemRefType();
|
||||
auto dataPtr =
|
||||
getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(),
|
||||
adaptor.indices(), rewriter);
|
||||
getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
|
||||
adaptor.getIndices(), rewriter);
|
||||
rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
|
||||
atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(),
|
||||
atomicOp, resultType, *maybeKind, dataPtr, adaptor.getValue(),
|
||||
LLVM::AtomicOrdering::acq_rel);
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -303,7 +303,7 @@ LogicalResult
|
|||
DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
|
||||
OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
MemRefType deallocType = operation.memref().getType().cast<MemRefType>();
|
||||
MemRefType deallocType = operation.getMemref().getType().cast<MemRefType>();
|
||||
if (!isAllocationSupported(operation, deallocType))
|
||||
return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
|
||||
rewriter.eraseOp(operation);
|
||||
|
@ -318,14 +318,14 @@ LogicalResult
|
|||
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto loc = loadOp.getLoc();
|
||||
auto memrefType = loadOp.memref().getType().cast<MemRefType>();
|
||||
auto memrefType = loadOp.getMemref().getType().cast<MemRefType>();
|
||||
if (!memrefType.getElementType().isSignlessInteger())
|
||||
return failure();
|
||||
|
||||
auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
|
||||
spirv::AccessChainOp accessChainOp =
|
||||
spirv::getElementPtr(typeConverter, memrefType, adaptor.memref(),
|
||||
adaptor.indices(), loc, rewriter);
|
||||
spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
|
||||
adaptor.getIndices(), loc, rewriter);
|
||||
|
||||
if (!accessChainOp)
|
||||
return failure();
|
||||
|
@ -413,12 +413,12 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
|
|||
LogicalResult
|
||||
LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto memrefType = loadOp.memref().getType().cast<MemRefType>();
|
||||
auto memrefType = loadOp.getMemref().getType().cast<MemRefType>();
|
||||
if (memrefType.getElementType().isSignlessInteger())
|
||||
return failure();
|
||||
auto loadPtr = spirv::getElementPtr(
|
||||
*getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.memref(),
|
||||
adaptor.indices(), loadOp.getLoc(), rewriter);
|
||||
*getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
|
||||
adaptor.getIndices(), loadOp.getLoc(), rewriter);
|
||||
|
||||
if (!loadPtr)
|
||||
return failure();
|
||||
|
@ -430,15 +430,15 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
|
|||
LogicalResult
|
||||
IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto memrefType = storeOp.memref().getType().cast<MemRefType>();
|
||||
auto memrefType = storeOp.getMemref().getType().cast<MemRefType>();
|
||||
if (!memrefType.getElementType().isSignlessInteger())
|
||||
return failure();
|
||||
|
||||
auto loc = storeOp.getLoc();
|
||||
auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
|
||||
spirv::AccessChainOp accessChainOp =
|
||||
spirv::getElementPtr(typeConverter, memrefType, adaptor.memref(),
|
||||
adaptor.indices(), loc, rewriter);
|
||||
spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
|
||||
adaptor.getIndices(), loc, rewriter);
|
||||
|
||||
if (!accessChainOp)
|
||||
return failure();
|
||||
|
@ -463,7 +463,7 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
|
|||
assert(dstBits % srcBits == 0);
|
||||
|
||||
if (srcBits == dstBits) {
|
||||
Value storeVal = adaptor.value();
|
||||
Value storeVal = adaptor.getValue();
|
||||
if (isBool)
|
||||
storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
|
||||
rewriter.replaceOpWithNewOp<spirv::StoreOp>(
|
||||
|
@ -494,7 +494,7 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
|
|||
rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
|
||||
clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);
|
||||
|
||||
Value storeVal = adaptor.value();
|
||||
Value storeVal = adaptor.getValue();
|
||||
if (isBool)
|
||||
storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
|
||||
storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter);
|
||||
|
@ -525,18 +525,18 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
|
|||
LogicalResult
|
||||
StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto memrefType = storeOp.memref().getType().cast<MemRefType>();
|
||||
auto memrefType = storeOp.getMemref().getType().cast<MemRefType>();
|
||||
if (memrefType.getElementType().isSignlessInteger())
|
||||
return failure();
|
||||
auto storePtr = spirv::getElementPtr(
|
||||
*getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.memref(),
|
||||
adaptor.indices(), storeOp.getLoc(), rewriter);
|
||||
*getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
|
||||
adaptor.getIndices(), storeOp.getLoc(), rewriter);
|
||||
|
||||
if (!storePtr)
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
|
||||
adaptor.value());
|
||||
adaptor.getValue());
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -357,7 +357,7 @@ struct Strategy<TransferReadOp> {
|
|||
static void getBufferIndices(TransferReadOp xferOp,
|
||||
SmallVector<Value, 8> &indices) {
|
||||
auto storeOp = getStoreOp(xferOp);
|
||||
auto prevIndices = memref::StoreOpAdaptor(storeOp).indices();
|
||||
auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
|
||||
indices.append(prevIndices.begin(), prevIndices.end());
|
||||
}
|
||||
|
||||
|
@ -463,7 +463,7 @@ struct Strategy<TransferWriteOp> {
|
|||
static void getBufferIndices(TransferWriteOp xferOp,
|
||||
SmallVector<Value, 8> &indices) {
|
||||
auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
|
||||
auto prevIndices = memref::LoadOpAdaptor(loadOp).indices();
|
||||
auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices();
|
||||
indices.append(prevIndices.begin(), prevIndices.end());
|
||||
}
|
||||
|
||||
|
|
|
@ -299,7 +299,7 @@ bool mlir::isValidDim(Value value, Region *region) {
|
|||
// The dim op is okay if its operand memref/tensor is defined at the top
|
||||
// level.
|
||||
if (auto dimOp = dyn_cast<memref::DimOp>(op))
|
||||
return isTopLevelValue(dimOp.source());
|
||||
return isTopLevelValue(dimOp.getSource());
|
||||
if (auto dimOp = dyn_cast<tensor::DimOp>(op))
|
||||
return isTopLevelValue(dimOp.getSource());
|
||||
return false;
|
||||
|
@ -2534,7 +2534,7 @@ OpFoldResult AffineLoadOp::fold(ArrayRef<Attribute> cstOperands) {
|
|||
if (!symbolTableOp)
|
||||
return {};
|
||||
auto global = dyn_cast_or_null<memref::GlobalOp>(
|
||||
SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.nameAttr()));
|
||||
SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.getNameAttr()));
|
||||
if (!global)
|
||||
return {};
|
||||
|
||||
|
|
|
@ -1635,7 +1635,7 @@ static void createNewDynamicSizes(MemRefType oldMemRefType,
|
|||
for (unsigned d = 0; d < oldMemRefType.getRank(); ++d) {
|
||||
if (oldMemRefShape[d] < 0) {
|
||||
// Use dynamicSizes of allocOp for dynamic dimension.
|
||||
inAffineApply.emplace_back(allocOp->dynamicSizes()[dynIdx]);
|
||||
inAffineApply.emplace_back(allocOp->getDynamicSizes()[dynIdx]);
|
||||
dynIdx++;
|
||||
} else {
|
||||
// Create ConstantOp for static dimension.
|
||||
|
@ -1681,7 +1681,7 @@ LogicalResult mlir::normalizeMemRef(memref::AllocOp *allocOp) {
|
|||
// Fetch a new memref type after normalizing the old memref to have an
|
||||
// identity map layout.
|
||||
MemRefType newMemRefType =
|
||||
normalizeMemRefType(memrefType, b, allocOp->symbolOperands().size());
|
||||
normalizeMemRefType(memrefType, b, allocOp->getSymbolOperands().size());
|
||||
if (newMemRefType == memrefType)
|
||||
// Either memrefType already had an identity map or the map couldn't be
|
||||
// transformed to an identity map.
|
||||
|
@ -1689,7 +1689,7 @@ LogicalResult mlir::normalizeMemRef(memref::AllocOp *allocOp) {
|
|||
|
||||
Value oldMemRef = allocOp->getResult();
|
||||
|
||||
SmallVector<Value, 4> symbolOperands(allocOp->symbolOperands());
|
||||
SmallVector<Value, 4> symbolOperands(allocOp->getSymbolOperands());
|
||||
AffineMap layoutMap = memrefType.getLayout().getAffineMap();
|
||||
memref::AllocOp newAlloc;
|
||||
// Check if `layoutMap` is a tiled layout. Only single layout map is
|
||||
|
@ -1704,10 +1704,10 @@ LogicalResult mlir::normalizeMemRef(memref::AllocOp *allocOp) {
|
|||
// Add the new dynamic sizes in new AllocOp.
|
||||
newAlloc =
|
||||
b.create<memref::AllocOp>(allocOp->getLoc(), newMemRefType,
|
||||
newDynamicSizes, allocOp->alignmentAttr());
|
||||
newDynamicSizes, allocOp->getAlignmentAttr());
|
||||
} else {
|
||||
newAlloc = b.create<memref::AllocOp>(allocOp->getLoc(), newMemRefType,
|
||||
allocOp->alignmentAttr());
|
||||
allocOp->getAlignmentAttr());
|
||||
}
|
||||
// Replace all uses of the old memref.
|
||||
if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
|
||||
|
|
|
@ -48,7 +48,7 @@ struct ConstantOpInterface
|
|||
return failure();
|
||||
memref::GlobalOp globalMemref = *globalOp;
|
||||
replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
|
||||
rewriter, op, globalMemref.type(), globalMemref.getName());
|
||||
rewriter, op, globalMemref.getType(), globalMemref.getName());
|
||||
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -615,12 +615,12 @@ struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
|
|||
|
||||
LogicalResult matchAndRewrite(memref::LoadOp load,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto toMemref = load.memref().getDefiningOp<ToMemrefOp>();
|
||||
auto toMemref = load.getMemref().getDefiningOp<ToMemrefOp>();
|
||||
if (!toMemref)
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.getTensor(),
|
||||
load.indices());
|
||||
load.getIndices());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -631,11 +631,12 @@ struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
|
|||
|
||||
LogicalResult matchAndRewrite(memref::DimOp dimOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto castOp = dimOp.source().getDefiningOp<ToMemrefOp>();
|
||||
auto castOp = dimOp.getSource().getDefiningOp<ToMemrefOp>();
|
||||
if (!castOp)
|
||||
return failure();
|
||||
Value newSource = castOp.getOperand();
|
||||
rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource, dimOp.index());
|
||||
rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource,
|
||||
dimOp.getIndex());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -131,8 +131,9 @@ public:
|
|||
SmallVector<Value> allocsAndAllocas;
|
||||
for (BufferPlacementAllocs::AllocEntry &entry : allocs)
|
||||
allocsAndAllocas.push_back(std::get<0>(entry));
|
||||
scopeOp->walk(
|
||||
[&](memref::AllocaOp op) { allocsAndAllocas.push_back(op.memref()); });
|
||||
scopeOp->walk([&](memref::AllocaOp op) {
|
||||
allocsAndAllocas.push_back(op.getMemref());
|
||||
});
|
||||
|
||||
for (auto allocValue : allocsAndAllocas) {
|
||||
if (!StateT::shouldHoistOpType(allocValue.getDefiningOp()))
|
||||
|
|
|
@ -158,11 +158,12 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment) {
|
|||
auto globalOp = dyn_cast<memref::GlobalOp>(&op);
|
||||
if (!globalOp)
|
||||
continue;
|
||||
if (!globalOp.initial_value().hasValue())
|
||||
if (!globalOp.getInitialValue().hasValue())
|
||||
continue;
|
||||
uint64_t opAlignment =
|
||||
globalOp.alignment().hasValue() ? globalOp.alignment().getValue() : 0;
|
||||
Attribute initialValue = globalOp.initial_value().getValue();
|
||||
uint64_t opAlignment = globalOp.getAlignment().hasValue()
|
||||
? globalOp.getAlignment().getValue()
|
||||
: 0;
|
||||
Attribute initialValue = globalOp.getInitialValue().getValue();
|
||||
if (opAlignment == alignment && initialValue == constantOp.getValue())
|
||||
return globalOp;
|
||||
}
|
||||
|
|
|
@ -80,7 +80,7 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
|
|||
for (BlockArgument bbArg : funcOp.getArguments()) {
|
||||
Value val = it.value();
|
||||
while (auto castOp = val.getDefiningOp<memref::CastOp>())
|
||||
val = castOp.source();
|
||||
val = castOp.getSource();
|
||||
|
||||
if (val == bbArg) {
|
||||
resultToArgs[it.index()] = bbArg.getArgNumber();
|
||||
|
|
|
@ -359,8 +359,8 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
|
|||
|
||||
for (OpOperand &operand : returnOp->getOpOperands()) {
|
||||
if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
|
||||
operand.set(castOp.source());
|
||||
resultTypes.push_back(castOp.source().getType());
|
||||
operand.set(castOp.getSource());
|
||||
resultTypes.push_back(castOp.getSource().getType());
|
||||
} else {
|
||||
resultTypes.push_back(operand.get().getType());
|
||||
}
|
||||
|
|
|
@ -1372,15 +1372,15 @@ struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
|
|||
|
||||
LogicalResult matchAndRewrite(memref::DimOp dimOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto index = dimOp.index().getDefiningOp<arith::ConstantIndexOp>();
|
||||
auto index = dimOp.getIndex().getDefiningOp<arith::ConstantIndexOp>();
|
||||
if (!index)
|
||||
return failure();
|
||||
|
||||
auto memrefType = dimOp.source().getType().dyn_cast<MemRefType>();
|
||||
auto memrefType = dimOp.getSource().getType().dyn_cast<MemRefType>();
|
||||
if (!memrefType || !memrefType.isDynamicDim(index.value()))
|
||||
return failure();
|
||||
|
||||
auto alloc = dimOp.source().getDefiningOp<AllocOp>();
|
||||
auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
|
||||
if (!alloc)
|
||||
return failure();
|
||||
|
||||
|
|
|
@ -108,7 +108,8 @@ defaultDeallocBufferCallBack(const LinalgPromotionOptions &options,
|
|||
OpBuilder &b, Value fullLocalView) {
|
||||
if (!options.useAlloca) {
|
||||
auto viewOp = cast<memref::ViewOp>(fullLocalView.getDefiningOp());
|
||||
b.create<memref::DeallocOp>(viewOp.source().getLoc(), viewOp.source());
|
||||
b.create<memref::DeallocOp>(viewOp.getSource().getLoc(),
|
||||
viewOp.getSource());
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -625,8 +625,8 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter,
|
|||
LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
|
||||
memref::CopyOp copyOp) {
|
||||
|
||||
auto srcType = copyOp.source().getType().cast<MemRefType>();
|
||||
auto dstType = copyOp.target().getType().cast<MemRefType>();
|
||||
auto srcType = copyOp.getSource().getType().cast<MemRefType>();
|
||||
auto dstType = copyOp.getTarget().getType().cast<MemRefType>();
|
||||
if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
|
@ -640,14 +640,14 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
|
|||
SmallVector<Value> indices(srcType.getRank(), zero);
|
||||
|
||||
Value readValue = rewriter.create<vector::TransferReadOp>(
|
||||
loc, readType, copyOp.source(), indices,
|
||||
loc, readType, copyOp.getSource(), indices,
|
||||
rewriter.getMultiDimIdentityMap(srcType.getRank()));
|
||||
if (readValue.getType().cast<VectorType>().getRank() == 0) {
|
||||
readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
|
||||
readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
|
||||
}
|
||||
Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
|
||||
loc, readValue, copyOp.target(), indices,
|
||||
loc, readValue, copyOp.getTarget(), indices,
|
||||
rewriter.getMultiDimIdentityMap(srcType.getRank()));
|
||||
rewriter.replaceOp(copyOp, writeValue->getResults());
|
||||
return success();
|
||||
|
@ -1168,8 +1168,8 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
|
|||
memref::CopyOp copyOp;
|
||||
for (auto &u : subView.getUses()) {
|
||||
if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
|
||||
assert(newCopyOp.target().getType().isa<MemRefType>());
|
||||
if (newCopyOp.target() != subView)
|
||||
assert(newCopyOp.getTarget().getType().isa<MemRefType>());
|
||||
if (newCopyOp.getTarget() != subView)
|
||||
continue;
|
||||
LDBG("copy candidate " << *newCopyOp);
|
||||
if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
|
||||
|
@ -1204,7 +1204,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
|
|||
LDBG("with maybeFillOp " << *maybeFillOp);
|
||||
|
||||
// `in` is the subview that memref.copy reads. Replace it.
|
||||
Value in = copyOp.source();
|
||||
Value in = copyOp.getSource();
|
||||
|
||||
// memref.copy + linalg.fill can be used to create a padded local buffer.
|
||||
// The `masked` attribute is only valid on this padded buffer.
|
||||
|
@ -1248,7 +1248,7 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
|
|||
memref::CopyOp copyOp;
|
||||
for (auto &u : subViewOp.getResult().getUses()) {
|
||||
if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
|
||||
if (newCopyOp.source() != subView)
|
||||
if (newCopyOp.getSource() != subView)
|
||||
continue;
|
||||
if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
|
||||
continue;
|
||||
|
@ -1260,8 +1260,8 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
|
|||
return failure();
|
||||
|
||||
// `out` is the subview copied into that we replace.
|
||||
assert(copyOp.target().getType().isa<MemRefType>());
|
||||
Value out = copyOp.target();
|
||||
assert(copyOp.getTarget().getType().isa<MemRefType>());
|
||||
Value out = copyOp.getTarget();
|
||||
|
||||
// Forward vector.transfer into copy.
|
||||
// memref.copy + linalg.fill can be used to create a padded local buffer.
|
||||
|
|
|
@ -980,10 +980,10 @@ SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
|
|||
Value outputTensor = operands[opOperand->getOperandNumber()];
|
||||
if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
|
||||
Value inserted = builder.create<tensor::InsertSliceOp>(
|
||||
loc, sliceOp.source().getType(), results[resultIdx], sliceOp.source(),
|
||||
sliceOp.offsets(), sliceOp.sizes(), sliceOp.strides(),
|
||||
sliceOp.static_offsets(), sliceOp.static_sizes(),
|
||||
sliceOp.static_strides());
|
||||
loc, sliceOp.getSource().getType(), results[resultIdx],
|
||||
sliceOp.getSource(), sliceOp.getOffsets(), sliceOp.getSizes(),
|
||||
sliceOp.getStrides(), sliceOp.getStaticOffsets(),
|
||||
sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
|
||||
tensorResults.push_back(inserted);
|
||||
} else {
|
||||
tensorResults.push_back(results[resultIdx]);
|
||||
|
|
|
@ -121,7 +121,7 @@ static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
|
|||
if (!memRefType)
|
||||
return op.emitOpError("result must be a memref");
|
||||
|
||||
if (static_cast<int64_t>(op.dynamicSizes().size()) !=
|
||||
if (static_cast<int64_t>(op.getDynamicSizes().size()) !=
|
||||
memRefType.getNumDynamicDims())
|
||||
return op.emitOpError("dimension operand count does not equal memref "
|
||||
"dynamic dimension count");
|
||||
|
@ -129,10 +129,10 @@ static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
|
|||
unsigned numSymbols = 0;
|
||||
if (!memRefType.getLayout().isIdentity())
|
||||
numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
|
||||
if (op.symbolOperands().size() != numSymbols)
|
||||
if (op.getSymbolOperands().size() != numSymbols)
|
||||
return op.emitOpError("symbol operand count does not equal memref symbol "
|
||||
"count: expected ")
|
||||
<< numSymbols << ", got " << op.symbolOperands().size();
|
||||
<< numSymbols << ", got " << op.getSymbolOperands().size();
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -158,7 +158,7 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
|
|||
PatternRewriter &rewriter) const override {
|
||||
// Check to see if any dimensions operands are constants. If so, we can
|
||||
// substitute and drop them.
|
||||
if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) {
|
||||
if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
|
||||
return matchPattern(operand, matchConstantIndex());
|
||||
}))
|
||||
return failure();
|
||||
|
@ -179,7 +179,7 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
|
|||
newShapeConstants.push_back(dimSize);
|
||||
continue;
|
||||
}
|
||||
auto dynamicSize = alloc.dynamicSizes()[dynamicDimPos];
|
||||
auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
|
||||
auto *defOp = dynamicSize.getDefiningOp();
|
||||
if (auto constantIndexOp =
|
||||
dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
|
||||
|
@ -201,8 +201,8 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
|
|||
|
||||
// Create and insert the alloc op for the new memref.
|
||||
auto newAlloc = rewriter.create<AllocLikeOp>(
|
||||
alloc.getLoc(), newMemRefType, dynamicSizes, alloc.symbolOperands(),
|
||||
alloc.alignmentAttr());
|
||||
alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(),
|
||||
alloc.getAlignmentAttr());
|
||||
// Insert a cast so we have the same type as the old alloc.
|
||||
auto resultCast =
|
||||
rewriter.create<CastOp>(alloc.getLoc(), alloc.getType(), newAlloc);
|
||||
|
@ -221,7 +221,7 @@ struct SimplifyDeadAlloc : public OpRewritePattern<T> {
|
|||
PatternRewriter &rewriter) const override {
|
||||
if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
|
||||
if (auto storeOp = dyn_cast<StoreOp>(op))
|
||||
return storeOp.value() == alloc;
|
||||
return storeOp.getValue() == alloc;
|
||||
return !isa<DeallocOp>(op);
|
||||
}))
|
||||
return failure();
|
||||
|
@ -254,12 +254,12 @@ void AllocaScopeOp::print(OpAsmPrinter &p) {
|
|||
bool printBlockTerminators = false;
|
||||
|
||||
p << ' ';
|
||||
if (!results().empty()) {
|
||||
if (!getResults().empty()) {
|
||||
p << " -> (" << getResultTypes() << ")";
|
||||
printBlockTerminators = true;
|
||||
}
|
||||
p << ' ';
|
||||
p.printRegion(bodyRegion(),
|
||||
p.printRegion(getBodyRegion(),
|
||||
/*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/printBlockTerminators);
|
||||
p.printOptionalAttrDict((*this)->getAttrs());
|
||||
|
@ -295,7 +295,7 @@ void AllocaScopeOp::getSuccessorRegions(
|
|||
return;
|
||||
}
|
||||
|
||||
regions.push_back(RegionSuccessor(&bodyRegion()));
|
||||
regions.push_back(RegionSuccessor(&getBodyRegion()));
|
||||
}
|
||||
|
||||
/// Given an operation, return whether this op is guaranteed to
|
||||
|
@ -467,7 +467,7 @@ void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult AssumeAlignmentOp::verify() {
|
||||
if (!llvm::isPowerOf2_32(alignment()))
|
||||
if (!llvm::isPowerOf2_32(getAlignment()))
|
||||
return emitOpError("alignment must be power of 2");
|
||||
return success();
|
||||
}
|
||||
|
@ -514,7 +514,7 @@ LogicalResult AssumeAlignmentOp::verify() {
|
|||
/// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
|
||||
/// ```
|
||||
bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
|
||||
MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
|
||||
MemRefType sourceType = castOp.getSource().getType().dyn_cast<MemRefType>();
|
||||
MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
|
||||
|
||||
// Requires ranked MemRefType.
|
||||
|
@ -652,30 +652,32 @@ struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
|
|||
bool modified = false;
|
||||
|
||||
// Check source.
|
||||
if (auto castOp = copyOp.source().getDefiningOp<CastOp>()) {
|
||||
auto fromType = castOp.source().getType().dyn_cast<MemRefType>();
|
||||
auto toType = castOp.source().getType().dyn_cast<MemRefType>();
|
||||
if (auto castOp = copyOp.getSource().getDefiningOp<CastOp>()) {
|
||||
auto fromType = castOp.getSource().getType().dyn_cast<MemRefType>();
|
||||
auto toType = castOp.getSource().getType().dyn_cast<MemRefType>();
|
||||
|
||||
if (fromType && toType) {
|
||||
if (fromType.getShape() == toType.getShape() &&
|
||||
fromType.getElementType() == toType.getElementType()) {
|
||||
rewriter.updateRootInPlace(
|
||||
copyOp, [&] { copyOp.sourceMutable().assign(castOp.source()); });
|
||||
rewriter.updateRootInPlace(copyOp, [&] {
|
||||
copyOp.getSourceMutable().assign(castOp.getSource());
|
||||
});
|
||||
modified = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check target.
|
||||
if (auto castOp = copyOp.target().getDefiningOp<CastOp>()) {
|
||||
auto fromType = castOp.source().getType().dyn_cast<MemRefType>();
|
||||
auto toType = castOp.source().getType().dyn_cast<MemRefType>();
|
||||
if (auto castOp = copyOp.getTarget().getDefiningOp<CastOp>()) {
|
||||
auto fromType = castOp.getSource().getType().dyn_cast<MemRefType>();
|
||||
auto toType = castOp.getSource().getType().dyn_cast<MemRefType>();
|
||||
|
||||
if (fromType && toType) {
|
||||
if (fromType.getShape() == toType.getShape() &&
|
||||
fromType.getElementType() == toType.getElementType()) {
|
||||
rewriter.updateRootInPlace(
|
||||
copyOp, [&] { copyOp.targetMutable().assign(castOp.source()); });
|
||||
rewriter.updateRootInPlace(copyOp, [&] {
|
||||
copyOp.getTargetMutable().assign(castOp.getSource());
|
||||
});
|
||||
modified = true;
|
||||
}
|
||||
}
|
||||
|
@ -691,7 +693,7 @@ struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
|
|||
|
||||
LogicalResult matchAndRewrite(CopyOp copyOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (copyOp.source() != copyOp.target())
|
||||
if (copyOp.getSource() != copyOp.getTarget())
|
||||
return failure();
|
||||
|
||||
rewriter.eraseOp(copyOp);
|
||||
|
@ -748,7 +750,7 @@ void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
|
|||
}
|
||||
|
||||
Optional<int64_t> DimOp::getConstantIndex() {
|
||||
if (auto constantOp = index().getDefiningOp<arith::ConstantOp>())
|
||||
if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
|
||||
return constantOp.getValue().cast<IntegerAttr>().getInt();
|
||||
return {};
|
||||
}
|
||||
|
@ -760,7 +762,7 @@ LogicalResult DimOp::verify() {
|
|||
return success();
|
||||
|
||||
// Check that constant index is not knowingly out of range.
|
||||
auto type = source().getType();
|
||||
auto type = getSource().getType();
|
||||
if (auto memrefType = type.dyn_cast<MemRefType>()) {
|
||||
if (*index >= memrefType.getRank())
|
||||
return emitOpError("index is out of range");
|
||||
|
@ -875,7 +877,7 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
|
|||
return {};
|
||||
|
||||
// Folding for unranked types (UnrankedMemRefType) is not supported.
|
||||
auto memrefType = source().getType().dyn_cast<MemRefType>();
|
||||
auto memrefType = getSource().getType().dyn_cast<MemRefType>();
|
||||
if (!memrefType)
|
||||
return {};
|
||||
|
||||
|
@ -889,7 +891,7 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
|
|||
unsigned unsignedIndex = index.getValue().getZExtValue();
|
||||
|
||||
// Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
|
||||
Operation *definingOp = source().getDefiningOp();
|
||||
Operation *definingOp = getSource().getDefiningOp();
|
||||
|
||||
if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
|
||||
return *(alloc.getDynamicSizes().begin() +
|
||||
|
@ -944,7 +946,7 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
|
|||
|
||||
LogicalResult matchAndRewrite(DimOp dim,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto reshape = dim.source().getDefiningOp<ReshapeOp>();
|
||||
auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
|
||||
|
||||
if (!reshape)
|
||||
return failure();
|
||||
|
@ -953,7 +955,8 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
|
|||
// was not mutated.
|
||||
rewriter.setInsertionPointAfter(reshape);
|
||||
Location loc = dim.getLoc();
|
||||
Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index());
|
||||
Value load =
|
||||
rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
|
||||
if (load.getType() != dim.getType())
|
||||
load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
|
||||
rewriter.replaceOp(dim, load);
|
||||
|
@ -1151,7 +1154,7 @@ LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
|
|||
|
||||
LogicalResult DmaWaitOp::verify() {
|
||||
// Check that the number of tag indices matches the tagMemRef rank.
|
||||
unsigned numTagIndices = tagIndices().size();
|
||||
unsigned numTagIndices = getTagIndices().size();
|
||||
unsigned tagMemRefRank = getTagMemRefRank();
|
||||
if (numTagIndices != tagMemRefRank)
|
||||
return emitOpError() << "expected tagIndices to have the same number of "
|
||||
|
@ -1223,8 +1226,8 @@ ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
|
|||
}
|
||||
|
||||
void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
|
||||
p << ' ' << memref() << "[" << indices() << "] : " << memref().getType()
|
||||
<< ' ';
|
||||
p << ' ' << getMemref() << "[" << getIndices()
|
||||
<< "] : " << getMemref().getType() << ' ';
|
||||
p.printRegion(getRegion());
|
||||
p.printOptionalAttrDict((*this)->getAttrs());
|
||||
}
|
||||
|
@ -1235,7 +1238,7 @@ void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
|
|||
|
||||
LogicalResult AtomicYieldOp::verify() {
|
||||
Type parentType = (*this)->getParentOp()->getResultTypes().front();
|
||||
Type resultType = result().getType();
|
||||
Type resultType = getResult().getType();
|
||||
if (parentType != resultType)
|
||||
return emitOpError() << "types mismatch between yield op: " << resultType
|
||||
<< " and its parent: " << parentType;
|
||||
|
@ -1290,15 +1293,15 @@ parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
|
|||
}
|
||||
|
||||
LogicalResult GlobalOp::verify() {
|
||||
auto memrefType = type().dyn_cast<MemRefType>();
|
||||
auto memrefType = getType().dyn_cast<MemRefType>();
|
||||
if (!memrefType || !memrefType.hasStaticShape())
|
||||
return emitOpError("type should be static shaped memref, but got ")
|
||||
<< type();
|
||||
<< getType();
|
||||
|
||||
// Verify that the initial value, if present, is either a unit attribute or
|
||||
// an elements attribute.
|
||||
if (initial_value().hasValue()) {
|
||||
Attribute initValue = initial_value().getValue();
|
||||
if (getInitialValue().hasValue()) {
|
||||
Attribute initValue = getInitialValue().getValue();
|
||||
if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
|
||||
return emitOpError("initial value should be a unit or elements "
|
||||
"attribute, but got ")
|
||||
|
@ -1315,7 +1318,7 @@ LogicalResult GlobalOp::verify() {
|
|||
}
|
||||
}
|
||||
|
||||
if (Optional<uint64_t> alignAttr = alignment()) {
|
||||
if (Optional<uint64_t> alignAttr = getAlignment()) {
|
||||
uint64_t alignment = *alignAttr;
|
||||
|
||||
if (!llvm::isPowerOf2_64(alignment))
|
||||
|
@ -1328,8 +1331,8 @@ LogicalResult GlobalOp::verify() {
|
|||
}
|
||||
|
||||
ElementsAttr GlobalOp::getConstantInitValue() {
|
||||
auto initVal = initial_value();
|
||||
if (constant() && initVal.hasValue())
|
||||
auto initVal = getInitialValue();
|
||||
if (getConstant() && initVal.hasValue())
|
||||
return initVal.getValue().cast<ElementsAttr>();
|
||||
return {};
|
||||
}
|
||||
|
@ -1343,16 +1346,16 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
|||
// Verify that the result type is same as the type of the referenced
|
||||
// memref.global op.
|
||||
auto global =
|
||||
symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
|
||||
symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
|
||||
if (!global)
|
||||
return emitOpError("'")
|
||||
<< name() << "' does not reference a valid global memref";
|
||||
<< getName() << "' does not reference a valid global memref";
|
||||
|
||||
Type resultType = result().getType();
|
||||
if (global.type() != resultType)
|
||||
Type resultType = getResult().getType();
|
||||
if (global.getType() != resultType)
|
||||
return emitOpError("result type ")
|
||||
<< resultType << " does not match type " << global.type()
|
||||
<< " of the global memref @" << name();
|
||||
<< resultType << " does not match type " << global.getType()
|
||||
<< " of the global memref @" << getName();
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -1378,11 +1381,11 @@ OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void PrefetchOp::print(OpAsmPrinter &p) {
|
||||
p << " " << memref() << '[';
|
||||
p.printOperands(indices());
|
||||
p << ']' << ", " << (isWrite() ? "write" : "read");
|
||||
p << ", locality<" << localityHint();
|
||||
p << ">, " << (isDataCache() ? "data" : "instr");
|
||||
p << " " << getMemref() << '[';
|
||||
p.printOperands(getIndices());
|
||||
p << ']' << ", " << (getIsWrite() ? "write" : "read");
|
||||
p << ", locality<" << getLocalityHint();
|
||||
p << ">, " << (getIsDataCache() ? "data" : "instr");
|
||||
p.printOptionalAttrDict(
|
||||
(*this)->getAttrs(),
|
||||
/*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
|
||||
|
@ -1513,7 +1516,7 @@ void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
|
|||
// completed automatically, like we have for subview and extract_slice.
|
||||
LogicalResult ReinterpretCastOp::verify() {
|
||||
// The source and result memrefs should be in the same memory space.
|
||||
auto srcType = source().getType().cast<BaseMemRefType>();
|
||||
auto srcType = getSource().getType().cast<BaseMemRefType>();
|
||||
auto resultType = getType().cast<MemRefType>();
|
||||
if (srcType.getMemorySpace() != resultType.getMemorySpace())
|
||||
return emitError("different memory spaces specified for source type ")
|
||||
|
@ -1524,7 +1527,7 @@ LogicalResult ReinterpretCastOp::verify() {
|
|||
|
||||
// Match sizes in result memref type and in static_sizes attribute.
|
||||
for (auto &en : llvm::enumerate(llvm::zip(
|
||||
resultType.getShape(), extractFromI64ArrayAttr(static_sizes())))) {
|
||||
resultType.getShape(), extractFromI64ArrayAttr(getStaticSizes())))) {
|
||||
int64_t resultSize = std::get<0>(en.value());
|
||||
int64_t expectedSize = std::get<1>(en.value());
|
||||
if (!ShapedType::isDynamic(resultSize) &&
|
||||
|
@ -1544,7 +1547,7 @@ LogicalResult ReinterpretCastOp::verify() {
|
|||
<< resultType;
|
||||
|
||||
// Match offset in result memref type and in static_offsets attribute.
|
||||
int64_t expectedOffset = extractFromI64ArrayAttr(static_offsets()).front();
|
||||
int64_t expectedOffset = extractFromI64ArrayAttr(getStaticOffsets()).front();
|
||||
if (!ShapedType::isDynamicStrideOrOffset(resultOffset) &&
|
||||
!ShapedType::isDynamicStrideOrOffset(expectedOffset) &&
|
||||
resultOffset != expectedOffset)
|
||||
|
@ -1553,7 +1556,7 @@ LogicalResult ReinterpretCastOp::verify() {
|
|||
|
||||
// Match strides in result memref type and in static_strides attribute.
|
||||
for (auto &en : llvm::enumerate(llvm::zip(
|
||||
resultStrides, extractFromI64ArrayAttr(static_strides())))) {
|
||||
resultStrides, extractFromI64ArrayAttr(getStaticStrides())))) {
|
||||
int64_t resultStride = std::get<0>(en.value());
|
||||
int64_t expectedStride = std::get<1>(en.value());
|
||||
if (!ShapedType::isDynamicStrideOrOffset(resultStride) &&
|
||||
|
@ -1568,15 +1571,15 @@ LogicalResult ReinterpretCastOp::verify() {
|
|||
}
|
||||
|
||||
OpFoldResult ReinterpretCastOp::fold(ArrayRef<Attribute> /*operands*/) {
|
||||
Value src = source();
|
||||
Value src = getSource();
|
||||
auto getPrevSrc = [&]() -> Value {
|
||||
// reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
|
||||
if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
|
||||
return prev.source();
|
||||
return prev.getSource();
|
||||
|
||||
// reinterpret_cast(cast(x)) -> reinterpret_cast(x).
|
||||
if (auto prev = src.getDefiningOp<CastOp>())
|
||||
return prev.source();
|
||||
return prev.getSource();
|
||||
|
||||
// reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
|
||||
// are 0.
|
||||
|
@ -1584,13 +1587,13 @@ OpFoldResult ReinterpretCastOp::fold(ArrayRef<Attribute> /*operands*/) {
|
|||
if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
|
||||
return isConstantIntValue(val, 0);
|
||||
}))
|
||||
return prev.source();
|
||||
return prev.getSource();
|
||||
|
||||
return nullptr;
|
||||
};
|
||||
|
||||
if (auto prevSrc = getPrevSrc()) {
|
||||
sourceMutable().assign(prevSrc);
|
||||
getSourceMutable().assign(prevSrc);
|
||||
return getResult();
|
||||
}
|
||||
|
||||
|
@ -1998,10 +2001,10 @@ public:
|
|||
|
||||
if (newResultType == op.getResultType()) {
|
||||
rewriter.updateRootInPlace(
|
||||
op, [&]() { op.srcMutable().assign(cast.source()); });
|
||||
op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
|
||||
} else {
|
||||
Value newOp = rewriter.create<CollapseShapeOp>(
|
||||
op->getLoc(), cast.source(), op.getReassociationIndices());
|
||||
op->getLoc(), cast.getSource(), op.getReassociationIndices());
|
||||
rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
|
||||
}
|
||||
return success();
|
||||
|
@ -2028,8 +2031,8 @@ OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ReshapeOp::verify() {
|
||||
Type operandType = source().getType();
|
||||
Type resultType = result().getType();
|
||||
Type operandType = getSource().getType();
|
||||
Type resultType = getResult().getType();
|
||||
|
||||
Type operandElementType = operandType.cast<ShapedType>().getElementType();
|
||||
Type resultElementType = resultType.cast<ShapedType>().getElementType();
|
||||
|
@ -2041,7 +2044,7 @@ LogicalResult ReshapeOp::verify() {
|
|||
if (!operandMemRefType.getLayout().isIdentity())
|
||||
return emitOpError("source memref type should have identity affine map");
|
||||
|
||||
int64_t shapeSize = shape().getType().cast<MemRefType>().getDimSize(0);
|
||||
int64_t shapeSize = getShape().getType().cast<MemRefType>().getDimSize(0);
|
||||
auto resultMemRefType = resultType.dyn_cast<MemRefType>();
|
||||
if (resultMemRefType) {
|
||||
if (!resultMemRefType.getLayout().isIdentity())
|
||||
|
@ -2296,7 +2299,7 @@ void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
|
|||
}
|
||||
|
||||
/// For ViewLikeOpInterface.
|
||||
Value SubViewOp::getViewSource() { return source(); }
|
||||
Value SubViewOp::getViewSource() { return getSource(); }
|
||||
|
||||
/// Return true if t1 and t2 have equal offsets (both dynamic or of same
|
||||
/// static value).
|
||||
|
@ -2381,9 +2384,9 @@ LogicalResult SubViewOp::verify() {
|
|||
|
||||
// Verify result type against inferred type.
|
||||
auto expectedType = SubViewOp::inferResultType(
|
||||
baseType, extractFromI64ArrayAttr(static_offsets()),
|
||||
extractFromI64ArrayAttr(static_sizes()),
|
||||
extractFromI64ArrayAttr(static_strides()));
|
||||
baseType, extractFromI64ArrayAttr(getStaticOffsets()),
|
||||
extractFromI64ArrayAttr(getStaticSizes()),
|
||||
extractFromI64ArrayAttr(getStaticStrides()));
|
||||
|
||||
auto result = isRankReducedMemRefType(expectedType.cast<MemRefType>(),
|
||||
subViewType, getMixedSizes());
|
||||
|
@ -2536,7 +2539,7 @@ public:
|
|||
}))
|
||||
return failure();
|
||||
|
||||
auto castOp = subViewOp.source().getDefiningOp<CastOp>();
|
||||
auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
|
||||
if (!castOp)
|
||||
return failure();
|
||||
|
||||
|
@ -2549,16 +2552,17 @@ public:
|
|||
// if the operation is rank-reducing.
|
||||
auto resultType = getCanonicalSubViewResultType(
|
||||
subViewOp.getType(), subViewOp.getSourceType(),
|
||||
castOp.source().getType().cast<MemRefType>(),
|
||||
castOp.getSource().getType().cast<MemRefType>(),
|
||||
subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
|
||||
subViewOp.getMixedStrides());
|
||||
if (!resultType)
|
||||
return failure();
|
||||
|
||||
Value newSubView = rewriter.create<SubViewOp>(
|
||||
subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
|
||||
subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
|
||||
subViewOp.static_sizes(), subViewOp.static_strides());
|
||||
subViewOp.getLoc(), resultType, castOp.getSource(),
|
||||
subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
|
||||
subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
|
||||
subViewOp.getStaticStrides());
|
||||
rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
|
||||
newSubView);
|
||||
return success();
|
||||
|
@ -2576,11 +2580,11 @@ public:
|
|||
if (!isTrivialSubViewOp(subViewOp))
|
||||
return failure();
|
||||
if (subViewOp.getSourceType() == subViewOp.getType()) {
|
||||
rewriter.replaceOp(subViewOp, subViewOp.source());
|
||||
rewriter.replaceOp(subViewOp, subViewOp.getSource());
|
||||
return success();
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
|
||||
subViewOp.source());
|
||||
subViewOp.getSource());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -2614,7 +2618,7 @@ void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|||
|
||||
OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
|
||||
auto resultShapedType = getResult().getType().cast<ShapedType>();
|
||||
auto sourceShapedType = source().getType().cast<ShapedType>();
|
||||
auto sourceShapedType = getSource().getType().cast<ShapedType>();
|
||||
|
||||
if (resultShapedType.hasStaticShape() &&
|
||||
resultShapedType == sourceShapedType) {
|
||||
|
@ -2669,9 +2673,9 @@ void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
|
|||
|
||||
// transpose $in $permutation attr-dict : type($in) `to` type(results)
|
||||
void TransposeOp::print(OpAsmPrinter &p) {
|
||||
p << " " << in() << " " << permutation();
|
||||
p << " " << getIn() << " " << getPermutation();
|
||||
p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
|
||||
p << " : " << in().getType() << " to " << getType();
|
||||
p << " : " << getIn().getType() << " to " << getType();
|
||||
}
|
||||
|
||||
ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
|
@ -2692,14 +2696,14 @@ ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
|
|||
}
|
||||
|
||||
LogicalResult TransposeOp::verify() {
|
||||
if (!permutation().isPermutation())
|
||||
if (!getPermutation().isPermutation())
|
||||
return emitOpError("expected a permutation map");
|
||||
if (permutation().getNumDims() != getShapedType().getRank())
|
||||
if (getPermutation().getNumDims() != getShapedType().getRank())
|
||||
return emitOpError("expected a permutation map of same rank as the input");
|
||||
|
||||
auto srcType = in().getType().cast<MemRefType>();
|
||||
auto srcType = getIn().getType().cast<MemRefType>();
|
||||
auto dstType = getType().cast<MemRefType>();
|
||||
auto transposedType = inferTransposeResultType(srcType, permutation());
|
||||
auto transposedType = inferTransposeResultType(srcType, getPermutation());
|
||||
if (dstType != transposedType)
|
||||
return emitOpError("output type ")
|
||||
<< dstType << " does not match transposed input type " << srcType
|
||||
|
@ -2737,13 +2741,13 @@ LogicalResult ViewOp::verify() {
|
|||
|
||||
// Verify that we have the correct number of sizes for the result type.
|
||||
unsigned numDynamicDims = viewType.getNumDynamicDims();
|
||||
if (sizes().size() != numDynamicDims)
|
||||
if (getSizes().size() != numDynamicDims)
|
||||
return emitError("incorrect number of size operands for type ") << viewType;
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
Value ViewOp::getViewSource() { return source(); }
|
||||
Value ViewOp::getViewSource() { return getSource(); }
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -2785,7 +2789,7 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
|
|||
newShapeConstants.push_back(dimSize);
|
||||
continue;
|
||||
}
|
||||
auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
|
||||
auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
|
||||
if (auto constantIndexOp =
|
||||
dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
|
||||
// Dynamic shape dimension will be folded.
|
||||
|
@ -2793,7 +2797,7 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
|
|||
} else {
|
||||
// Dynamic shape dimension not folded; copy operand from old memref.
|
||||
newShapeConstants.push_back(dimSize);
|
||||
newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
|
||||
newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
|
||||
}
|
||||
dynamicDimPos++;
|
||||
}
|
||||
|
@ -2806,9 +2810,9 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
|
|||
return failure();
|
||||
|
||||
// Create new ViewOp.
|
||||
auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
|
||||
viewOp.getOperand(0),
|
||||
viewOp.byte_shift(), newOperands);
|
||||
auto newViewOp = rewriter.create<ViewOp>(
|
||||
viewOp.getLoc(), newMemRefType, viewOp.getOperand(0),
|
||||
viewOp.getByteShift(), newOperands);
|
||||
// Insert a cast so we have the same type as the old memref type.
|
||||
rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
|
||||
return success();
|
||||
|
@ -2829,7 +2833,8 @@ struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
|
|||
if (!allocOp)
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
|
||||
viewOp.byte_shift(), viewOp.sizes());
|
||||
viewOp.getByteShift(),
|
||||
viewOp.getSizes());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -2849,14 +2854,14 @@ LogicalResult AtomicRMWOp::verify() {
|
|||
if (getMemRefType().getRank() != getNumOperands() - 2)
|
||||
return emitOpError(
|
||||
"expects the number of subscripts to be equal to memref rank");
|
||||
switch (kind()) {
|
||||
switch (getKind()) {
|
||||
case arith::AtomicRMWKind::addf:
|
||||
case arith::AtomicRMWKind::maxf:
|
||||
case arith::AtomicRMWKind::minf:
|
||||
case arith::AtomicRMWKind::mulf:
|
||||
if (!value().getType().isa<FloatType>())
|
||||
if (!getValue().getType().isa<FloatType>())
|
||||
return emitOpError() << "with kind '"
|
||||
<< arith::stringifyAtomicRMWKind(kind())
|
||||
<< arith::stringifyAtomicRMWKind(getKind())
|
||||
<< "' expects a floating-point type";
|
||||
break;
|
||||
case arith::AtomicRMWKind::addi:
|
||||
|
@ -2867,9 +2872,9 @@ LogicalResult AtomicRMWOp::verify() {
|
|||
case arith::AtomicRMWKind::muli:
|
||||
case arith::AtomicRMWKind::ori:
|
||||
case arith::AtomicRMWKind::andi:
|
||||
if (!value().getType().isa<IntegerType>())
|
||||
if (!getValue().getType().isa<IntegerType>())
|
||||
return emitOpError() << "with kind '"
|
||||
<< arith::stringifyAtomicRMWKind(kind())
|
||||
<< arith::stringifyAtomicRMWKind(getKind())
|
||||
<< "' expects an integer type";
|
||||
break;
|
||||
default:
|
||||
|
@ -2880,7 +2885,7 @@ LogicalResult AtomicRMWOp::verify() {
|
|||
|
||||
OpFoldResult AtomicRMWOp::fold(ArrayRef<Attribute> operands) {
|
||||
/// atomicrmw(memrefcast) -> atomicrmw
|
||||
if (succeeded(foldMemRefCast(*this, value())))
|
||||
if (succeeded(foldMemRefCast(*this, getValue())))
|
||||
return getResult();
|
||||
return OpFoldResult();
|
||||
}
|
||||
|
|
|
@ -36,7 +36,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
|
|||
// produces the input of the op we're rewriting (for 'SubViewOp' the input
|
||||
// is called the "source" value). We can only combine them if both 'op' and
|
||||
// 'sourceOp' are 'SubViewOp'.
|
||||
auto sourceOp = op.source().getDefiningOp<memref::SubViewOp>();
|
||||
auto sourceOp = op.getSource().getDefiningOp<memref::SubViewOp>();
|
||||
if (!sourceOp)
|
||||
return failure();
|
||||
|
||||
|
@ -119,7 +119,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
|
|||
|
||||
// This replaces 'op' but leaves 'sourceOp' alone; if it no longer has any
|
||||
// uses it can be removed by a (separate) dead code elimination pass.
|
||||
rewriter.replaceOpWithNewOp<memref::SubViewOp>(op, sourceOp.source(),
|
||||
rewriter.replaceOpWithNewOp<memref::SubViewOp>(op, sourceOp.getSource(),
|
||||
offsets, sizes, strides);
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -46,7 +46,7 @@ public:
|
|||
LogicalResult matchAndRewrite(memref::AtomicRMWOp op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
arith::CmpFPredicate predicate;
|
||||
switch (op.kind()) {
|
||||
switch (op.getKind()) {
|
||||
case arith::AtomicRMWKind::maxf:
|
||||
predicate = arith::CmpFPredicate::OGT;
|
||||
break;
|
||||
|
@ -59,12 +59,12 @@ public:
|
|||
|
||||
auto loc = op.getLoc();
|
||||
auto genericOp = rewriter.create<memref::GenericAtomicRMWOp>(
|
||||
loc, op.memref(), op.indices());
|
||||
loc, op.getMemref(), op.getIndices());
|
||||
OpBuilder bodyBuilder =
|
||||
OpBuilder::atBlockEnd(genericOp.getBody(), rewriter.getListener());
|
||||
|
||||
Value lhs = genericOp.getCurrentValue();
|
||||
Value rhs = op.value();
|
||||
Value rhs = op.getValue();
|
||||
Value cmp = bodyBuilder.create<arith::CmpFOp>(loc, predicate, lhs, rhs);
|
||||
Value select = bodyBuilder.create<arith::SelectOp>(loc, cmp, lhs, rhs);
|
||||
bodyBuilder.create<memref::AtomicYieldOp>(loc, select);
|
||||
|
@ -82,7 +82,7 @@ public:
|
|||
|
||||
LogicalResult matchAndRewrite(memref::ReshapeOp op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
auto shapeType = op.shape().getType().cast<MemRefType>();
|
||||
auto shapeType = op.getShape().getType().cast<MemRefType>();
|
||||
if (!shapeType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
|
@ -98,7 +98,7 @@ public:
|
|||
// Load dynamic sizes from the shape input, use constants for static dims.
|
||||
if (op.getType().isDynamicDim(i)) {
|
||||
Value index = rewriter.create<arith::ConstantIndexOp>(loc, i);
|
||||
size = rewriter.create<memref::LoadOp>(loc, op.shape(), index);
|
||||
size = rewriter.create<memref::LoadOp>(loc, op.getShape(), index);
|
||||
if (!size.getType().isa<IndexType>())
|
||||
size = rewriter.create<arith::IndexCastOp>(
|
||||
loc, rewriter.getIndexType(), size);
|
||||
|
@ -113,7 +113,7 @@ public:
|
|||
stride = rewriter.create<arith::MulIOp>(loc, stride, size);
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
|
||||
op, op.getType(), op.source(), /*offset=*/rewriter.getIndexAttr(0),
|
||||
op, op.getType(), op.getSource(), /*offset=*/rewriter.getIndexAttr(0),
|
||||
sizes, strides);
|
||||
return success();
|
||||
}
|
||||
|
@ -130,11 +130,11 @@ struct ExpandOpsPass : public ExpandOpsBase<ExpandOpsPass> {
|
|||
target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect>();
|
||||
target.addDynamicallyLegalOp<memref::AtomicRMWOp>(
|
||||
[](memref::AtomicRMWOp op) {
|
||||
return op.kind() != arith::AtomicRMWKind::maxf &&
|
||||
op.kind() != arith::AtomicRMWKind::minf;
|
||||
return op.getKind() != arith::AtomicRMWKind::maxf &&
|
||||
op.getKind() != arith::AtomicRMWKind::minf;
|
||||
});
|
||||
target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
|
||||
return !op.shape().getType().cast<MemRefType>().hasStaticShape();
|
||||
return !op.getShape().getType().cast<MemRefType>().hasStaticShape();
|
||||
});
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
|
|
|
@ -162,7 +162,7 @@ template <typename LoadOpTy>
|
|||
void LoadOpOfSubViewFolder<LoadOpTy>::replaceOp(
|
||||
LoadOpTy loadOp, memref::SubViewOp subViewOp, ArrayRef<Value> sourceIndices,
|
||||
PatternRewriter &rewriter) const {
|
||||
rewriter.replaceOpWithNewOp<LoadOpTy>(loadOp, subViewOp.source(),
|
||||
rewriter.replaceOpWithNewOp<LoadOpTy>(loadOp, subViewOp.getSource(),
|
||||
sourceIndices);
|
||||
}
|
||||
|
||||
|
@ -174,7 +174,7 @@ void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
|
|||
if (transferReadOp.getTransferRank() == 0)
|
||||
return;
|
||||
rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
|
||||
transferReadOp, transferReadOp.getVectorType(), subViewOp.source(),
|
||||
transferReadOp, transferReadOp.getVectorType(), subViewOp.getSource(),
|
||||
sourceIndices,
|
||||
getPermutationMapAttr(rewriter.getContext(), subViewOp,
|
||||
transferReadOp.getPermutationMap()),
|
||||
|
@ -187,7 +187,7 @@ void StoreOpOfSubViewFolder<StoreOpTy>::replaceOp(
|
|||
StoreOpTy storeOp, memref::SubViewOp subViewOp,
|
||||
ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
|
||||
rewriter.replaceOpWithNewOp<StoreOpTy>(storeOp, storeOp.getValue(),
|
||||
subViewOp.source(), sourceIndices);
|
||||
subViewOp.getSource(), sourceIndices);
|
||||
}
|
||||
|
||||
template <>
|
||||
|
@ -198,7 +198,7 @@ void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
|
|||
if (transferWriteOp.getTransferRank() == 0)
|
||||
return;
|
||||
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
|
||||
transferWriteOp, transferWriteOp.getVector(), subViewOp.source(),
|
||||
transferWriteOp, transferWriteOp.getVector(), subViewOp.getSource(),
|
||||
sourceIndices,
|
||||
getPermutationMapAttr(rewriter.getContext(), subViewOp,
|
||||
transferWriteOp.getPermutationMap()),
|
||||
|
|
|
@ -23,7 +23,7 @@ static bool overrideBuffer(Operation *op, Value buffer) {
|
|||
auto copyOp = dyn_cast<memref::CopyOp>(op);
|
||||
if (!copyOp)
|
||||
return false;
|
||||
return copyOp.target() == buffer;
|
||||
return copyOp.getTarget() == buffer;
|
||||
}
|
||||
|
||||
/// Replace the uses of `oldOp` with the given `val` and for subview uses
|
||||
|
@ -45,9 +45,9 @@ static void replaceUsesAndPropagateType(Operation *oldOp, Value val,
|
|||
builder.setInsertionPoint(subviewUse);
|
||||
Type newType = memref::SubViewOp::inferRankReducedResultType(
|
||||
subviewUse.getType().getShape(), val.getType().cast<MemRefType>(),
|
||||
extractFromI64ArrayAttr(subviewUse.static_offsets()),
|
||||
extractFromI64ArrayAttr(subviewUse.static_sizes()),
|
||||
extractFromI64ArrayAttr(subviewUse.static_strides()));
|
||||
extractFromI64ArrayAttr(subviewUse.getStaticOffsets()),
|
||||
extractFromI64ArrayAttr(subviewUse.getStaticSizes()),
|
||||
extractFromI64ArrayAttr(subviewUse.getStaticStrides()));
|
||||
Value newSubview = builder.create<memref::SubViewOp>(
|
||||
subviewUse->getLoc(), newType.cast<MemRefType>(), val,
|
||||
subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
|
||||
|
|
|
@ -105,9 +105,9 @@ Operation::operand_range getIndices(Operation *op) {
|
|||
if (auto copyOp = dyn_cast<DeviceAsyncCopyOp>(op))
|
||||
return copyOp.getDstIndices();
|
||||
if (auto loadOp = dyn_cast<memref::LoadOp>(op))
|
||||
return loadOp.indices();
|
||||
return loadOp.getIndices();
|
||||
if (auto storeOp = dyn_cast<memref::StoreOp>(op))
|
||||
return storeOp.indices();
|
||||
return storeOp.getIndices();
|
||||
if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
|
||||
return vectorReadOp.getIndices();
|
||||
if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
|
||||
|
@ -121,9 +121,9 @@ void setIndices(Operation *op, ArrayRef<Value> indices) {
|
|||
if (auto copyOp = dyn_cast<DeviceAsyncCopyOp>(op))
|
||||
return copyOp.getDstIndicesMutable().assign(indices);
|
||||
if (auto loadOp = dyn_cast<memref::LoadOp>(op))
|
||||
return loadOp.indicesMutable().assign(indices);
|
||||
return loadOp.getIndicesMutable().assign(indices);
|
||||
if (auto storeOp = dyn_cast<memref::StoreOp>(op))
|
||||
return storeOp.indicesMutable().assign(indices);
|
||||
return storeOp.getIndicesMutable().assign(indices);
|
||||
if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
|
||||
return vectorReadOp.getIndicesMutable().assign(indices);
|
||||
if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
|
||||
|
@ -250,14 +250,17 @@ public:
|
|||
Operation *op = getOperation();
|
||||
SmallVector<memref::AllocOp> shmAllocOps;
|
||||
op->walk([&](memref::AllocOp allocOp) {
|
||||
if (allocOp.memref().getType().cast<MemRefType>().getMemorySpaceAsInt() !=
|
||||
if (allocOp.getMemref()
|
||||
.getType()
|
||||
.cast<MemRefType>()
|
||||
.getMemorySpaceAsInt() !=
|
||||
gpu::GPUDialect::getWorkgroupAddressSpace())
|
||||
return;
|
||||
shmAllocOps.push_back(allocOp);
|
||||
});
|
||||
for (auto allocOp : shmAllocOps) {
|
||||
if (failed(optimizeSharedMemoryReadsAndWrites(getOperation(),
|
||||
allocOp.memref())))
|
||||
allocOp.getMemref())))
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -55,7 +55,7 @@ static bool haveNoReadsAfterWriteExceptSameIndex(
|
|||
const BlockAndValueMapping &firstToSecondPloopIndices) {
|
||||
DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
|
||||
firstPloop.getBody()->walk([&](memref::StoreOp store) {
|
||||
bufferStores[store.getMemRef()].push_back(store.indices());
|
||||
bufferStores[store.getMemRef()].push_back(store.getIndices());
|
||||
});
|
||||
auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
|
||||
// Stop if the memref is defined in secondPloop body. Careful alias analysis
|
||||
|
@ -75,7 +75,7 @@ static bool haveNoReadsAfterWriteExceptSameIndex(
|
|||
// Check that the load indices of secondPloop coincide with store indices of
|
||||
// firstPloop for the same memrefs.
|
||||
auto storeIndices = write->second.front();
|
||||
auto loadIndices = load.indices();
|
||||
auto loadIndices = load.getIndices();
|
||||
if (storeIndices.size() != loadIndices.size())
|
||||
return WalkResult::interrupt();
|
||||
for (int i = 0, e = storeIndices.size(); i < e; ++i) {
|
||||
|
|
|
@ -108,18 +108,18 @@ struct SparseTensorConversionPass
|
|||
return converter.isLegal(op.getOperandTypes());
|
||||
});
|
||||
target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
|
||||
return converter.isLegal(op.source().getType()) &&
|
||||
converter.isLegal(op.dest().getType());
|
||||
return converter.isLegal(op.getSource().getType()) &&
|
||||
converter.isLegal(op.getDest().getType());
|
||||
});
|
||||
target.addDynamicallyLegalOp<tensor::ExpandShapeOp>(
|
||||
[&](tensor::ExpandShapeOp op) {
|
||||
return converter.isLegal(op.src().getType()) &&
|
||||
converter.isLegal(op.result().getType());
|
||||
return converter.isLegal(op.getSrc().getType()) &&
|
||||
converter.isLegal(op.getResult().getType());
|
||||
});
|
||||
target.addDynamicallyLegalOp<tensor::CollapseShapeOp>(
|
||||
[&](tensor::CollapseShapeOp op) {
|
||||
return converter.isLegal(op.src().getType()) &&
|
||||
converter.isLegal(op.result().getType());
|
||||
return converter.isLegal(op.getSrc().getType()) &&
|
||||
converter.isLegal(op.getResult().getType());
|
||||
});
|
||||
target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
|
||||
[&](bufferization::AllocTensorOp op) {
|
||||
|
|
|
@ -1840,26 +1840,26 @@ public:
|
|||
LogicalResult matchAndRewrite(tensor::ExpandShapeOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
auto encDst = getSparseTensorEncoding(op.result().getType());
|
||||
auto encSrc = getSparseTensorEncoding(op.src().getType());
|
||||
auto encDst = getSparseTensorEncoding(op.getResult().getType());
|
||||
auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
|
||||
// Since a pure dense expansion is very cheap (change of view), for
|
||||
// sparse2dense or dense2sparse, we can simply unfuse a sparse
|
||||
// conversion from the actual expansion operation itself.
|
||||
if (encDst && encSrc) {
|
||||
return failure(); // TODO: implement sparse2sparse
|
||||
} else if (encSrc) {
|
||||
RankedTensorType rtp = op.src().getType().cast<RankedTensorType>();
|
||||
RankedTensorType rtp = op.getSrc().getType().cast<RankedTensorType>();
|
||||
auto denseTp =
|
||||
RankedTensorType::get(rtp.getShape(), rtp.getElementType());
|
||||
auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.src());
|
||||
auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
|
||||
op->setOperand(0, convert);
|
||||
return success();
|
||||
} else if (encDst) {
|
||||
RankedTensorType rtp = op.result().getType().cast<RankedTensorType>();
|
||||
RankedTensorType rtp = op.getResult().getType().cast<RankedTensorType>();
|
||||
auto denseTp =
|
||||
RankedTensorType::get(rtp.getShape(), rtp.getElementType());
|
||||
auto reshape = rewriter.create<tensor::ExpandShapeOp>(
|
||||
loc, denseTp, op.src(), op.getReassociation());
|
||||
loc, denseTp, op.getSrc(), op.getReassociation());
|
||||
Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
|
||||
rewriter.replaceOp(op, convert);
|
||||
return success();
|
||||
|
@ -1877,26 +1877,26 @@ public:
|
|||
LogicalResult matchAndRewrite(tensor::CollapseShapeOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
auto encDst = getSparseTensorEncoding(op.result().getType());
|
||||
auto encSrc = getSparseTensorEncoding(op.src().getType());
|
||||
auto encDst = getSparseTensorEncoding(op.getResult().getType());
|
||||
auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
|
||||
// Since a pure dense collapse is very cheap (change of view), for
|
||||
// sparse2dense or dense2sparse, we can simply unfuse a sparse
|
||||
// conversion from the actual collapse operation itself.
|
||||
if (encDst && encSrc) {
|
||||
return failure(); // TODO: implement sparse2sparse
|
||||
} else if (encSrc) {
|
||||
RankedTensorType rtp = op.src().getType().cast<RankedTensorType>();
|
||||
RankedTensorType rtp = op.getSrc().getType().cast<RankedTensorType>();
|
||||
auto denseTp =
|
||||
RankedTensorType::get(rtp.getShape(), rtp.getElementType());
|
||||
auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.src());
|
||||
auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
|
||||
op->setOperand(0, convert);
|
||||
return success();
|
||||
} else if (encDst) {
|
||||
RankedTensorType rtp = op.result().getType().cast<RankedTensorType>();
|
||||
RankedTensorType rtp = op.getResult().getType().cast<RankedTensorType>();
|
||||
auto denseTp =
|
||||
RankedTensorType::get(rtp.getShape(), rtp.getElementType());
|
||||
auto reshape = rewriter.create<tensor::CollapseShapeOp>(
|
||||
loc, denseTp, op.src(), op.getReassociation());
|
||||
loc, denseTp, op.getSrc(), op.getReassociation());
|
||||
Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
|
||||
rewriter.replaceOp(op, convert);
|
||||
return success();
|
||||
|
|
|
@ -200,7 +200,7 @@ struct DimOpInterface
|
|||
if (failed(v))
|
||||
return failure();
|
||||
replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
|
||||
dimOp.index());
|
||||
dimOp.getIndex());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -332,7 +332,7 @@ struct ExtractOpInterface
|
|||
if (failed(srcMemref))
|
||||
return failure();
|
||||
replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
|
||||
extractOp.indices());
|
||||
extractOp.getIndices());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue