[mlir][linalg][bufferize] Support arith.index_cast bufferization

This is in preparation of switching `-tensor-constant-bufferize` and `-arith-bufferize` to BufferizableOpInterface-based implementations.

Differential Revision: https://reviews.llvm.org/D118324
This commit is contained in:
Matthias Springer 2022-01-27 19:37:58 +09:00
parent 3b259a6842
commit dbd1bbced9
2 changed files with 61 additions and 0 deletions

View File

@ -57,6 +57,49 @@ struct ConstantOpInterface
}
};
struct IndexCastOpInterface
: public BufferizableOpInterface::ExternalModel<IndexCastOpInterface,
arith::IndexCastOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const BufferizationState &state) const {
return false;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const BufferizationState &state) const {
return false;
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
const BufferizationState &state) const {
return op->getResult(0);
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const BufferizationState &state) const {
return BufferRelation::Equivalent;
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationState &state) const {
auto castOp = cast<arith::IndexCastOp>(op);
Value source = *state.getBuffer(rewriter, op->getOpOperand(0) /*in*/);
auto sourceType = source.getType().cast<BaseMemRefType>();
// Result type should have same layout and address space as the source type.
MemRefLayoutAttrInterface layout = {};
if (auto rankedMemRefType = sourceType.dyn_cast<MemRefType>())
layout = rankedMemRefType.getLayout();
Type resultType =
getMemRefType(castOp.getType().cast<TensorType>(), state.getOptions(),
layout, sourceType.getMemorySpace());
replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, source,
resultType);
return success();
}
};
} // namespace arith_ext
} // namespace comprehensive_bufferize
} // namespace linalg
@ -65,4 +108,6 @@ struct ConstantOpInterface
void mlir::linalg::comprehensive_bufferize::arith_ext::
registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
registry.addOpInterface<arith::ConstantOp, arith_ext::ConstantOpInterface>();
registry
.addOpInterface<arith::IndexCastOp, arith_ext::IndexCastOpInterface>();
}

View File

@ -96,3 +96,19 @@ func @rank_reducing(
}
return %5: tensor<?x1x6x8xf32>
}
// -----
// CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0)>
// CHECK-LABEL: func @index_cast(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<i32>, %[[SCALAR:.*]]: i32
func @index_cast(%tensor: tensor<i32>, %scalar: i32) -> (tensor<index>, index) {
%index_tensor = arith.index_cast %tensor : tensor<i32> to tensor<index>
%index_scalar = arith.index_cast %scalar : i32 to index
return %index_tensor, %index_scalar : tensor<index>, index
}
// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<i32, #[[$MAP]]>
// CHECK-NEXT: %[[INDEX_MEMREF:.*]] = arith.index_cast %[[MEMREF]]
// CHECK-SAME: memref<i32, #[[$MAP]]> to memref<index, #[[$MAP]]>
// CHECK-NEXT: %[[INDEX_TENSOR:.*]] = bufferization.to_tensor %[[INDEX_MEMREF]]
// CHECK: return %[[INDEX_TENSOR]]