forked from OSchip/llvm-project
[mlir][vector] Address post-commit review comments on vector ops folding patterns
Differential Revision: https://reviews.llvm.org/D90183
This commit is contained in:
parent
b85f2f5c5f
commit
9081e7594d
|
@ -850,10 +850,12 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
|
|||
return Value();
|
||||
// Get the nth dimension size starting from lowest dimension.
|
||||
auto getDimReverse = [](VectorType type, int64_t n) {
|
||||
return type.getDimSize(type.getRank() - n - 1);
|
||||
return type.getShape().take_back(n+1).front();
|
||||
};
|
||||
int64_t destinationRank =
|
||||
extractOp.getVectorType().getRank() - extractOp.position().size();
|
||||
extractOp.getType().isa<VectorType>()
|
||||
? extractOp.getType().cast<VectorType>().getRank()
|
||||
: 0;
|
||||
if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
|
||||
return Value();
|
||||
if (destinationRank > 0) {
|
||||
|
@ -861,6 +863,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
|
|||
for (int64_t i = 0; i < destinationRank; i++) {
|
||||
// The lowest dimension of of the destination must match the lowest
|
||||
// dimension of the shapecast op source.
|
||||
// TODO: This case could be support in a canonicalization pattern.
|
||||
if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
|
||||
getDimReverse(destinationType, i))
|
||||
return Value();
|
||||
|
@ -891,6 +894,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
|
|||
}
|
||||
std::reverse(newStrides.begin(), newStrides.end());
|
||||
SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
|
||||
// OpBuilder is only used as a helper to build an I64ArrayAttr.
|
||||
OpBuilder b(extractOp.getContext());
|
||||
extractOp.setAttr(ExtractOp::getPositionAttrName(),
|
||||
b.getI64ArrayAttr(newPosition));
|
||||
|
@ -1632,8 +1636,8 @@ static LogicalResult verify(ExtractStridedSliceOp op) {
|
|||
}
|
||||
|
||||
// When the source of ExtractStrided comes from a chain of InsertStrided ops try
|
||||
// to use the source o the InsertStrided ops if we can detect that the extracted
|
||||
// vector is a subset of one of the vector inserted.
|
||||
// to use the source of the InsertStrided ops if we can detect that the
|
||||
// extracted vector is a subset of one of the vector inserted.
|
||||
static LogicalResult
|
||||
foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
|
||||
// Helper to extract integer out of ArrayAttr.
|
||||
|
|
|
@ -160,20 +160,20 @@ func @extract_strided_fold_negative(%a: vector<4x4xf32>, %b: vector<8x16xf32>)
|
|||
|
||||
// Case where we need to go through 2 level of insert element.
|
||||
// CHECK-LABEL: extract_strided_fold_insert
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<1x4xf32>,
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: vector<2x8xf32>, %[[ARG1:.*]]: vector<1x4xf32>,
|
||||
// CHECK-NEXT: %[[EXT:.*]] = vector.extract_strided_slice %[[ARG1]]
|
||||
// CHECK-SAME: {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]}
|
||||
// CHECK-SAME: {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]}
|
||||
// CHECK-SAME: : vector<1x4xf32> to vector<1x1xf32>
|
||||
// CHECK-NEXT: return %[[EXT]] : vector<1x1xf32>
|
||||
func @extract_strided_fold_insert(%a: vector<2x4xf32>, %b: vector<1x4xf32>,
|
||||
func @extract_strided_fold_insert(%a: vector<2x8xf32>, %b: vector<1x4xf32>,
|
||||
%c : vector<1x4xf32>) -> (vector<1x1xf32>) {
|
||||
%0 = vector.insert_strided_slice %b, %a {offsets = [0, 0], strides = [1, 1]}
|
||||
: vector<1x4xf32> into vector<2x4xf32>
|
||||
%0 = vector.insert_strided_slice %b, %a {offsets = [0, 1], strides = [1, 1]}
|
||||
: vector<1x4xf32> into vector<2x8xf32>
|
||||
%1 = vector.insert_strided_slice %c, %0 {offsets = [1, 0], strides = [1, 1]}
|
||||
: vector<1x4xf32> into vector<2x4xf32>
|
||||
: vector<1x4xf32> into vector<2x8xf32>
|
||||
%2 = vector.extract_strided_slice %1
|
||||
{offsets = [0, 1], sizes = [1, 1], strides = [1, 1]}
|
||||
: vector<2x4xf32> to vector<1x1xf32>
|
||||
: vector<2x8xf32> to vector<1x1xf32>
|
||||
return %2 : vector<1x1xf32>
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue