[mlir] Add folding of tensor.cast -> subtensor_insert
Differential Revision: https://reviews.llvm.org/D97059
parent 236aab0b0c
commit 0ee4bf151c
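At the IR level, the new fold lets a subtensor_insert whose source operand is produced by a shape-erasing tensor.cast consume the original value directly. A minimal, rank-preserving sketch with illustrative types and sizes (the test added at the bottom of this diff exercises the rank-reducing case):

// Before canonicalization: the inserted value goes through a cast that only
// drops static shape information.
%cast = tensor.cast %a : tensor<4xi8> to tensor<?xi8>
%res = subtensor_insert %cast into %b[0] [4] [1] : tensor<?xi8> into tensor<8xi8>

// After canonicalization: the cast is folded away and the operand keeps its
// original static type.
%res = subtensor_insert %a into %b[0] [4] [1] : tensor<4xi8> into tensor<8xi8>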
@@ -61,6 +61,10 @@ namespace tensor {
 /// ```
 bool canFoldIntoConsumerOp(CastOp castOp);
 
+/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
+/// that can be folded.
+LogicalResult foldTensorCast(Operation *op);
+
 } // namespace tensor
 } // namespace mlir
 
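For context, canFoldIntoConsumerOp (whose header documentation example closes just above) roughly accepts only casts that do not make the type more static; an illustrative sketch, with assumed types:

// Foldable into a consumer: the cast merely erases a static size, so the
// consumer can safely use the original, more precisely typed value instead.
%relaxed = tensor.cast %static : tensor<16x32xi8> to tensor<?x32xi8>

// Not foldable: the cast refines a dynamic size to a static one; bypassing it
// would discard a type refinement the consumer may rely on.
%refined = tensor.cast %dynamic : tensor<?x32xi8> to tensor<16x32xi8>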
@@ -3790,6 +3790,8 @@ OpFoldResult SubTensorInsertOp::fold(ArrayRef<Attribute>) {
   if (getSourceType() == getType() &&
       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
     return this->source();
+  if (succeeded(tensor::foldTensorCast(*this)))
+    return this->source();
   return OpFoldResult();
 }
 
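For reference, the pre-existing branch above handles the identity case, where the whole destination is overwritten by a same-typed source; a small illustrative sketch:

// Zero offsets, unit strides, and sizes covering the entire destination mean
// the result is just the inserted tensor, so the op folds to its source.
%r = subtensor_insert %a into %b[0, 0] [16, 32] [1, 1]
    : tensor<16x32xi8> into tensor<16x32xi8>
// canonicalizes to: all uses of %r are replaced by %a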
@@ -73,6 +73,20 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
   return true;
 }
 
+/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
+/// that can be folded.
+LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
+  bool folded = false;
+  for (OpOperand &operand : op->getOpOperands()) {
+    auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
+    if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
+      operand.set(castOp.getOperand());
+      folded = true;
+    }
+  }
+  return success(folded);
+}
+
 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   if (inputs.size() != 1 || outputs.size() != 1)
     return false;
@@ -237,3 +237,18 @@ func @rank_reducing_tensor_of_cast(%arg : tensor<4x6x16x32xi8>) -> tensor<16x32x
   %1 = subtensor %0[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<?x?x16x32xi8> to tensor<16x32xi8>
   return %1 : tensor<16x32xi8>
 }
+
+// -----
+
+// CHECK-LABEL: func @rank_reducing_subtensor_insert_of_cast
+// CHECK-SAME: %[[A:.[a-z0-9A-Z_]+]]: tensor<16x32xi8>
+// CHECK-SAME: %[[B:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
+// CHECK: %[[S:.+]] = subtensor_insert %[[A]] into %[[B]][0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<16x32xi8> into tensor<4x6x16x32xi8>
+// Tensor cast is folded away.
+// CHECK-NOT: tensor.cast
+// CHECK: return %[[S]] : tensor<4x6x16x32xi8>
+func @rank_reducing_subtensor_insert_of_cast(%a : tensor<16x32xi8>, %b : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
+  %cast = tensor.cast %a : tensor<16x32xi8> to tensor<?x32xi8>
+  %res = subtensor_insert %cast into %b[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<?x32xi8> into tensor<4x6x16x32xi8>
+  return %res : tensor<4x6x16x32xi8>
+}