[mlir][linalg] Fix pad tensor cast folding with changed type
`PadTensorOp` has verification logic requiring a result dimension to be static whenever the corresponding source dimension and padding values are all static. Cast folding can add static information to the source operand of a `PadTensorOp`, turning a previously valid operation into an invalid one. Change the canonicalization pattern to account for this.
Commit 9a82482313 (parent b06426da76).
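The failure mode, concretely (a minimal sketch mirroring the `pad_tensor_after_cast_different_shape` test added below; `%arg0` is assumed to be a function argument of type `tensor<?x64x?x?xf32>` and `%cst` a zero `f32` constant):

```mlir
// Valid: dim 1 of the pad source is dynamic, so the verifier places no
// static-size constraint on dim 1 of the result.
%dynamic = tensor.cast %arg0 : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
%padded = linalg.pad_tensor %dynamic low[0, 0, 1, 1] high[0, 0, 1, 1] {
^bb0(%i: index, %j: index, %k: index, %l: index):
  linalg.yield %cst : f32
} : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>

// Invalid if the cast is naively folded into the source: dim 1 of the
// source becomes the static 64 and its padding amounts are the static 0,
// so verification demands a static result size of 64, yet the result
// type still says '?':
//   linalg.pad_tensor %arg0 low[0, 0, 1, 1] high[0, 0, 1, 1] {...}
//       : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
```

The fix, in the `FoldSourceTensorCast` canonicalization pattern: rather than unconditionally swapping the cast's source into the existing op, first infer the result type the op would have with that source: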
```diff
@@ -1229,9 +1229,26 @@ struct FoldSourceTensorCast : public OpRewritePattern<PadTensorOp> {
     if (!tensor::canFoldIntoConsumerOp(castOp))
       return failure();
 
-    rewriter.updateRootInPlace(padTensorOp, [&]() {
-      padTensorOp.sourceMutable().assign(castOp.source());
-    });
+    auto newResultType = PadTensorOp::inferResultType(
+        castOp.source().getType().cast<RankedTensorType>(),
+        extractFromI64ArrayAttr(padTensorOp.static_low()),
+        extractFromI64ArrayAttr(padTensorOp.static_high()));
+
+    if (newResultType == padTensorOp.getResultType()) {
+      rewriter.updateRootInPlace(padTensorOp, [&]() {
+        padTensorOp.sourceMutable().assign(castOp.source());
+      });
+    } else {
+      auto newOp = rewriter.create<PadTensorOp>(
+          padTensorOp->getLoc(), newResultType, padTensorOp.source(),
+          padTensorOp.low(), padTensorOp.high(), padTensorOp.static_low(),
+          padTensorOp.static_high(), /*output=*/nullptr);
+      BlockAndValueMapping mapper;
+      padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
+
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(
+          padTensorOp, padTensorOp.getResultType(), newOp);
+    }
     return success();
   }
 };
```
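When the inferred type matches the op's existing result type, the source is replaced in place exactly as before. When it differs, the pattern instead builds a new `PadTensorOp` with the inferred (more static) result type, clones the padding region into it, and inserts a `tensor.cast` back to the original result type, so users of the op's result see the same type as before.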
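The added tests exercise both paths. In the first, the padding on dim 1 is the static 0, so folding the cast lets the pad produce `tensor<?x64x?x?xf32>`, which is then cast back to the original `tensor<?x?x?x?xf32>`. In the second, dim 1 is padded by the dynamic `%padding`, so the inferred result type stays fully dynamic, matches the original, and the source is updated in place with no trailing cast: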
```diff
@@ -627,6 +627,55 @@ func @pad_tensor_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
   } : tensor<5x6xf32> to tensor<5x6xf32>
   return %0 : tensor<5x6xf32>
 }
 
+// -----
+
+// CHECK-LABEL: func @pad_tensor_after_cast_different_shape(
+// CHECK-SAME:   %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
+// CHECK:        %[[CST:.*]] = constant 0.000000e+00 : f32
+// CHECK:        %[[PADDED:.*]] = linalg.pad_tensor %[[INPUT]]
+// CHECK-SAME:     low[0, 0, 1, 1] high[0, 0, 1, 1] {
+// CHECK:        ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
+// CHECK:          linalg.yield %[[CST]] : f32
+// CHECK:        } : tensor<?x64x?x?xf32> to tensor<?x64x?x?xf32>
+// CHECK:        %[[DYNAMIC:.*]] = tensor.cast %[[PADDED:.*]] :
+// CHECK-SAME:     tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
+// CHECK:        return %[[DYNAMIC]] : tensor<?x?x?x?xf32>
+// CHECK:      }
+func @pad_tensor_after_cast_different_shape(%arg0: tensor<?x64x?x?xf32>)
+    -> tensor<?x?x?x?xf32> {
+  %cst = constant 0.000000e+00 : f32
+  %dynamic = tensor.cast %arg0 : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
+  %padded = linalg.pad_tensor %dynamic low[0, 0, 1, 1] high[0, 0, 1, 1] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):  // no predecessors
+    linalg.yield %cst: f32
+  } : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
+  return %padded: tensor<?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @pad_tensor_after_cast_same_shape(
+// CHECK-SAME:   %[[INPUT:.*]]: tensor<?x64x?x?xf32>,
+// CHECK-SAME:   %[[PADDING:.*]]: index) -> tensor<?x?x?x?xf32> {
+// CHECK:        %[[CST:.*]] = constant 0.000000e+00 : f32
+// CHECK:        %[[PADDED:.*]] = linalg.pad_tensor %[[INPUT]]
+// CHECK-SAME:     low[0, %[[PADDING]], 1, 1] high[0, %[[PADDING]], 1, 1] {
+// CHECK:        ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
+// CHECK:          linalg.yield %[[CST]] : f32
+// CHECK:        } : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
+// CHECK:        return %[[PADDED:.*]] : tensor<?x?x?x?xf32>
+// CHECK:      }
+func @pad_tensor_after_cast_same_shape(%arg0: tensor<?x64x?x?xf32>, %padding : index)
+    -> tensor<?x?x?x?xf32> {
+  %cst = constant 0.000000e+00 : f32
+  %dynamic = tensor.cast %arg0 : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
+  %padded = linalg.pad_tensor %dynamic low[0, %padding, 1, 1] high[0, %padding, 1, 1] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):  // no predecessors
+    linalg.yield %cst: f32
+  } : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
+  return %padded: tensor<?x?x?x?xf32>
+}
+
 // -----
 
 func @propogate_casts(%arg0 : tensor<?x?xf32>, %arg1 : f32, %arg2 : index,
     %arg3 : index) -> tensor<?x?xf32> {
   %c0 = constant 0 : index
```