forked from OSchip/llvm-project
[mlir][arith][bufferize] Fix tensors with different layouts after bufferization
Insert a cast if the two tensors with identical layout (that are passed to `arith.select`) have different layout maps after bufferization. Differential Revision: https://reviews.llvm.org/D123321
This commit is contained in:
parent
5626bd4289
commit
8b09141909
|
@ -129,6 +129,7 @@ struct SelectOpInterface
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
BufferizationState &state) const {
|
BufferizationState &state) const {
|
||||||
auto selectOp = cast<arith::SelectOp>(op);
|
auto selectOp = cast<arith::SelectOp>(op);
|
||||||
|
Location loc = selectOp.getLoc();
|
||||||
|
|
||||||
// `getBuffer` introduces copies if an OpOperand bufferizes out-of-place.
|
// `getBuffer` introduces copies if an OpOperand bufferizes out-of-place.
|
||||||
// TODO: It would be more efficient to copy the result of the `select` op
|
// TODO: It would be more efficient to copy the result of the `select` op
|
||||||
|
@ -139,6 +140,26 @@ struct SelectOpInterface
|
||||||
*state.getBuffer(rewriter, selectOp->getOpOperand(1) /*true_value*/);
|
*state.getBuffer(rewriter, selectOp->getOpOperand(1) /*true_value*/);
|
||||||
Value falseBuffer =
|
Value falseBuffer =
|
||||||
*state.getBuffer(rewriter, selectOp->getOpOperand(2) /*false_value*/);
|
*state.getBuffer(rewriter, selectOp->getOpOperand(2) /*false_value*/);
|
||||||
|
|
||||||
|
// The "true" and the "false" operands must have the same type. If the
|
||||||
|
// buffers have different types, they differ only in their layout map. Cast
|
||||||
|
// both of them to the most dynamic MemRef type.
|
||||||
|
if (trueBuffer.getType() != falseBuffer.getType()) {
|
||||||
|
auto trueType = trueBuffer.getType().cast<MemRefType>();
|
||||||
|
auto tensorType = selectOp.getTrueValue().getType().cast<TensorType>();
|
||||||
|
int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
|
||||||
|
SmallVector<int64_t> dynamicStrides(tensorType.getRank(),
|
||||||
|
ShapedType::kDynamicStrideOrOffset);
|
||||||
|
AffineMap stridedLayout = makeStridedLinearLayoutMap(
|
||||||
|
dynamicStrides, dynamicOffset, op->getContext());
|
||||||
|
BaseMemRefType castedType = bufferization::getMemRefType(
|
||||||
|
tensorType, state.getOptions(), AffineMapAttr::get(stridedLayout),
|
||||||
|
trueType.getMemorySpace());
|
||||||
|
trueBuffer = rewriter.create<memref::CastOp>(loc, castedType, trueBuffer);
|
||||||
|
falseBuffer =
|
||||||
|
rewriter.create<memref::CastOp>(loc, castedType, falseBuffer);
|
||||||
|
}
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<arith::SelectOp>(
|
replaceOpWithNewBufferizedOp<arith::SelectOp>(
|
||||||
rewriter, op, selectOp.getCondition(), trueBuffer, falseBuffer);
|
rewriter, op, selectOp.getCondition(), trueBuffer, falseBuffer);
|
||||||
return success();
|
return success();
|
||||||
|
|
|
@ -105,4 +105,18 @@ func @copy_deallocated() -> tensor<10xf32> {
|
||||||
return %0 : tensor<10xf32>
|
return %0 : tensor<10xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @select_different_tensors(
|
||||||
|
// CHECK-SAME: %[[t:.*]]: tensor<?xf32>
|
||||||
|
func @select_different_tensors(%t: tensor<?xf32>, %sz: index, %c: i1) -> tensor<?xf32> {
|
||||||
|
// CHECK-DAG: %[[m:.*]] = bufferization.to_memref %[[t]] : memref<?xf32, #{{.*}}>
|
||||||
|
// CHECK-DAG: %[[alloc:.*]] = memref.alloc(%{{.*}}) {{.*}} : memref<?xf32>
|
||||||
|
%0 = linalg.init_tensor [%sz] : tensor<?xf32>
|
||||||
|
|
||||||
|
// A cast must be inserted because %t and %0 have different memref types.
|
||||||
|
// CHECK: %[[casted:.*]] = memref.cast %[[alloc]] : memref<?xf32> to memref<?xf32, #{{.*}}>
|
||||||
|
// CHECK: arith.select %{{.*}}, %[[casted]], %[[m]]
|
||||||
|
%1 = arith.select %c, %0, %t : tensor<?xf32>
|
||||||
|
return %1 : tensor<?xf32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue