[mlir] Add folding of tensor.cast -> subtensor_insert

Differential Revision: https://reviews.llvm.org/D97059
This commit is contained in:
Nicolas Vasilache 2021-02-19 17:04:12 +00:00
parent 236aab0b0c
commit 0ee4bf151c
4 changed files with 35 additions and 0 deletions

View File

@ -61,6 +61,10 @@ namespace tensor {
/// ```
bool canFoldIntoConsumerOp(CastOp castOp);
/// Rewrites, in place, every operand of `op` that is produced by a
/// tensor::CastOp accepted by `canFoldIntoConsumerOp`, so that `op` consumes
/// the cast's input directly instead of the cast's result.
/// Returns success if at least one operand was rewritten, failure otherwise.
LogicalResult foldTensorCast(Operation *op);
} // namespace tensor
} // namespace mlir

View File

@ -3790,6 +3790,8 @@ OpFoldResult SubTensorInsertOp::fold(ArrayRef<Attribute>) {
// When the source already has the result type and
// foldIdentityOffsetSizeAndStrideOpInterface succeeds, the op folds to its
// source value.
if (getSourceType() == getType() &&
succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
return this->source();
// foldTensorCast mutates this op's operands in place. An in-place fold must be
// reported by returning the op's own result: returning source() here would ask
// the folder to replace the (destination-typed) result with the source value,
// whose type differs from the result type in the rank-reducing case.
if (succeeded(tensor::foldTensorCast(*this)))
return getResult();
// Nothing to fold.
return OpFoldResult();
}

View File

@ -73,6 +73,20 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
return true;
}
/// Bypasses foldable tensor casts feeding `op`: each operand defined by a
/// tensor::CastOp for which `canFoldIntoConsumerOp` holds is redirected to the
/// cast's own operand. `op` is modified in place.
/// Returns success when at least one operand was redirected, failure otherwise.
LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
  bool anyOperandRewritten = false;
  for (OpOperand &use : op->getOpOperands()) {
    auto producer = use.get().getDefiningOp<tensor::CastOp>();
    // Skip operands that are not produced by a cast, or whose cast must stay.
    if (!producer || !tensor::canFoldIntoConsumerOp(producer))
      continue;
    use.set(producer.getOperand());
    anyOperandRewritten = true;
  }
  return anyOperandRewritten ? success() : failure();
}
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;

View File

@ -237,3 +237,18 @@ func @rank_reducing_tensor_of_cast(%arg : tensor<4x6x16x32xi8>) -> tensor<16x32x
%1 = subtensor %0[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<?x?x16x32xi8> to tensor<16x32xi8>
return %1 : tensor<16x32xi8>
}
// -----
// CHECK-LABEL: func @rank_reducing_subtensor_insert_of_cast
// CHECK-SAME: %[[A:.[a-z0-9A-Z_]+]]: tensor<16x32xi8>
// CHECK-SAME: %[[B:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
// CHECK: %[[S:.+]] = subtensor_insert %[[A]] into %[[B]][0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<16x32xi8> into tensor<4x6x16x32xi8>
// Tensor cast is folded away.
// CHECK-NOT: tensor.cast
// CHECK: return %[[S]] : tensor<4x6x16x32xi8>
// The inserted source %a is first cast from a static type (tensor<16x32xi8>)
// to a more dynamic one (tensor<?x32xi8>). The expectations above require the
// subtensor_insert to consume %a directly, with no cast left in the output.
func @rank_reducing_subtensor_insert_of_cast(%a : tensor<16x32xi8>, %b : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
%cast = tensor.cast %a : tensor<16x32xi8> to tensor<?x32xi8>
%res = subtensor_insert %cast into %b[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<?x32xi8> into tensor<4x6x16x32xi8>
return %res : tensor<4x6x16x32xi8>
}