[mlir][linalg][bufferize] Support arith.index_cast bufferization
This is in preparation for switching `-tensor-constant-bufferize` and `-arith-bufferize` to BufferizableOpInterface-based implementations.

Differential Revision: https://reviews.llvm.org/D118324
This commit is contained in:
parent 3b259a6842
commit dbd1bbced9
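For context, the case this commit adds support for is `arith.index_cast` on tensor values (the scalar form involves no buffers). A minimal example of the tensor-level op, mirroring the test added below:

%0 = arith.index_cast %t : tensor<i32> to tensor<index>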
@@ -57,6 +57,49 @@ struct ConstantOpInterface
  }
};

struct IndexCastOpInterface
    : public BufferizableOpInterface::ExternalModel<IndexCastOpInterface,
                                                    arith::IndexCastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return op->getResult(0);
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto castOp = cast<arith::IndexCastOp>(op);

    Value source = *state.getBuffer(rewriter, op->getOpOperand(0) /*in*/);
    auto sourceType = source.getType().cast<BaseMemRefType>();

    // Result type should have same layout and address space as the source
    // type.
    MemRefLayoutAttrInterface layout = {};
    if (auto rankedMemRefType = sourceType.dyn_cast<MemRefType>())
      layout = rankedMemRefType.getLayout();
    Type resultType =
        getMemRefType(castOp.getType().cast<TensorType>(), state.getOptions(),
                      layout, sourceType.getMemorySpace());

    replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, source,
                                                     resultType);
    return success();
  }
};
} // namespace arith_ext
} // namespace comprehensive_bufferize
} // namespace linalg
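In effect, `bufferize` recreates the index_cast on the operand's buffer, giving the result the source buffer's layout and memory space. A sketch of the rewrite on IR, with types and layout map mirroring the CHECK lines in the test below:

// Before bufferization: the op works on tensors.
%0 = arith.index_cast %t : tensor<i32> to tensor<index>

// After bufferization: the operand's buffer is fetched (shown here via
// bufferization.to_memref) and the cast is re-created on memrefs.
%m = bufferization.to_memref %t : memref<i32, #map>
%0 = arith.index_cast %m : memref<i32, #map> to memref<index, #map>

The remaining hooks only describe the op to the analysis: it reports no memory read or write of its own, and its single result aliases the source buffer with BufferRelation::Equivalent.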
@@ -65,4 +108,6 @@ struct ConstantOpInterface
void mlir::linalg::comprehensive_bufferize::arith_ext::
    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
  registry.addOpInterface<arith::ConstantOp, arith_ext::ConstantOpInterface>();
  registry
      .addOpInterface<arith::IndexCastOp, arith_ext::IndexCastOpInterface>();
}
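A minimal usage sketch, assuming a hypothetical client that wants these external models attached before running bufferization (the function name comes from the diff; the caller and the header declaring the function are assumptions here):

#include "mlir/IR/DialectRegistry.h"
// Assumed: also include the header that declares
// arith_ext::registerBufferizableOpInterfaceExternalModels.

void registerArithBufferizationModels(mlir::DialectRegistry &registry) {
  // Attaches ConstantOpInterface and IndexCastOpInterface as external
  // models on arith.constant and arith.index_cast, respectively.
  mlir::linalg::comprehensive_bufferize::arith_ext::
      registerBufferizableOpInterfaceExternalModels(registry);
}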
@@ -96,3 +96,19 @@ func @rank_reducing(
  }
  return %5: tensor<?x1x6x8xf32>
}

// -----

// CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0)>
// CHECK-LABEL: func @index_cast(
//  CHECK-SAME: %[[TENSOR:.*]]: tensor<i32>, %[[SCALAR:.*]]: i32
func @index_cast(%tensor: tensor<i32>, %scalar: i32) -> (tensor<index>, index) {
  %index_tensor = arith.index_cast %tensor : tensor<i32> to tensor<index>
  %index_scalar = arith.index_cast %scalar : i32 to index
  return %index_tensor, %index_scalar : tensor<index>, index
}
// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<i32, #[[$MAP]]>
// CHECK-NEXT: %[[INDEX_MEMREF:.*]] = arith.index_cast %[[MEMREF]]
// CHECK-SAME: memref<i32, #[[$MAP]]> to memref<index, #[[$MAP]]>
// CHECK-NEXT: %[[INDEX_TENSOR:.*]] = bufferization.to_tensor %[[INDEX_MEMREF]]
// CHECK: return %[[INDEX_TENSOR]]
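The new test is FileCheck-based; its RUN line sits outside this hunk, but for such a file it would typically look like the following (the exact pass flag is an assumption, not shown in this diff):

// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize | FileCheck %s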