[mlir] Fix bad rebase landed in acb69f3b7c.

Differential Revision: https://reviews.llvm.org/D92265
This commit is contained in:
Christian Sigg 2020-11-28 13:46:43 +01:00
parent acb69f3b7c
commit e9e45b3887
1 changed file with 16 additions and 17 deletions
mlir/lib/Conversion/GPUCommon

View File

@ -177,7 +177,7 @@ public:
private:
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(gpu::AllocOp allocOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@ -191,7 +191,7 @@ public:
private:
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(gpu::DeallocOp deallocOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@ -343,18 +343,16 @@ LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
}
LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
Operation *op, ArrayRef<Value> operands,
gpu::AllocOp allocOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
auto allocOp = cast<gpu::AllocOp>(op);
MemRefType memRefType = allocOp.getType();
if (failed(areAllLLVMTypes(op, operands, rewriter)) ||
if (failed(areAllLLVMTypes(allocOp, operands, rewriter)) ||
!isSupportedMemRefType(memRefType) ||
failed(
isAsyncWithOneDependency(rewriter, cast<gpu::AsyncOpInterface>(op))))
failed(isAsyncWithOneDependency(rewriter, allocOp)))
return failure();
auto loc = op->getLoc();
auto loc = allocOp.getLoc();
// Get shape of the memref as values: static sizes are constant
// values and dynamic sizes are passed to 'alloc' as operands.
@ -367,7 +365,8 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
// Allocate the underlying buffer and store a pointer to it in the MemRef
// descriptor.
Type elementPtrType = this->getElementPtrType(memRefType);
auto adaptor = gpu::AllocOpAdaptor(operands, op->getAttrDictionary());
auto adaptor = gpu::AllocOpAdaptor(
operands, allocOp.getOperation()->getAttrDictionary());
auto stream = adaptor.asyncDependencies().front();
Value allocatedPtr =
allocCallBuilder.create(loc, rewriter, {sizeBytes, stream}).getResult(0);
@ -381,29 +380,29 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
auto memRefDescriptor = this->createMemRefDescriptor(
loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);
rewriter.replaceOp(op, {memRefDescriptor, stream});
rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
return success();
}
// Rewrites a gpu::DeallocOp into the call emitted by deallocCallBuilder.
// Returns failure() when areAllLLVMTypes or isAsyncWithOneDependency fails;
// otherwise emits the call, replaces the op with its stream value, and returns
// success().
// NOTE(review): this is diff text -- each removed line appears to be followed
// directly by its replacement (the `op` forms look like the pre-change text,
// the `deallocOp` forms the post-change text); confirm against the real file.
LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
Operation *op, ArrayRef<Value> operands,
gpu::DeallocOp deallocOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
if (failed(areAllLLVMTypes(op, operands, rewriter)) ||
failed(
isAsyncWithOneDependency(rewriter, cast<gpu::AsyncOpInterface>(op))))
if (failed(areAllLLVMTypes(deallocOp, operands, rewriter)) ||
failed(isAsyncWithOneDependency(rewriter, deallocOp)))
return failure();
Location loc = op->getLoc();
Location loc = deallocOp.getLoc();
// The adaptor gives named access to the converted operands
// (memref, asyncDependencies) used below.
auto adaptor = gpu::DeallocOpAdaptor(operands, op->getAttrDictionary());
auto adaptor = gpu::DeallocOpAdaptor(
operands, deallocOp.getOperation()->getAttrDictionary());
// Read the allocated pointer out of the memref operand's descriptor.
Value pointer =
MemRefDescriptor(adaptor.memref()).allocatedPtr(rewriter, loc);
// Bitcast the pointer to llvmPointerType before handing it to the call.
auto casted = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pointer);
// The first async dependency is passed to the call as the stream argument.
Value stream = adaptor.asyncDependencies().front();
deallocCallBuilder.create(loc, rewriter, {casted, stream});
// The op is replaced by that same stream value.
rewriter.replaceOp(op, {stream});
rewriter.replaceOp(deallocOp, {stream});
return success();
}