[mlir] Allow empty position in vector.insert and vector.extract

Such ops are no-ops and are folded to their respective `source`/`vector` operand.

Differential Revision: https://reviews.llvm.org/D101879
This commit is contained in:
Matthias Springer 2021-05-13 12:53:15 +09:00
parent c52cbe63e4
commit 864adf399e
5 changed files with 33 additions and 21 deletions

View File

@ -764,7 +764,9 @@ def Vector_InsertOp :
return dest().getType().cast<VectorType>();
}
}];
let hasCanonicalizer = 1;
let hasFolder = 1;
}
def Vector_InsertSlicesOp :

View File

@ -656,6 +656,12 @@ public:
if (!llvmResultType)
return failure();
// Extract entire vector. Should be handled by folder, but just to be safe.
if (positionArrayAttr.empty()) {
rewriter.replaceOp(extractOp, adaptor.vector());
return success();
}
// One-shot extraction of vector from array (only requires extractvalue).
if (resultType.isa<VectorType>()) {
Value extracted = rewriter.create<LLVM::ExtractValueOp>(
@ -762,6 +768,13 @@ public:
if (!llvmResultType)
return failure();
// Overwrite entire vector with value. Should be handled by folder, but
// just to be safe.
if (positionArrayAttr.empty()) {
rewriter.replaceOp(insertOp, adaptor.source());
return success();
}
// One-shot insertion of a vector into an array (only requires insertvalue).
if (sourceType.isa<VectorType>()) {
Value inserted = rewriter.create<LLVM::InsertValueOp>(

View File

@ -872,8 +872,6 @@ static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) {
static LogicalResult verify(vector::ExtractOp op) {
auto positionAttr = op.position().getValue();
if (positionAttr.empty())
return op.emitOpError("expected non-empty position attribute");
if (positionAttr.size() > static_cast<unsigned>(op.getVectorType().getRank()))
return op.emitOpError(
"expected position attribute of rank smaller than vector rank");
@ -1151,6 +1149,8 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
}
OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
if (position().empty())
return vector();
if (succeeded(foldExtractOpFromExtractChain(*this)))
return getResult();
if (succeeded(foldExtractOpFromTranspose(*this)))
@ -1557,8 +1557,6 @@ void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
static LogicalResult verify(InsertOp op) {
auto positionAttr = op.position().getValue();
if (positionAttr.empty())
return op.emitOpError("expected non-empty position attribute");
auto destVectorType = op.getDestVectorType();
if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank()))
return op.emitOpError(
@ -1612,6 +1610,15 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<InsertToShapeCast>(context);
}
// Eliminates insert operations that produce values identical to their source
// value. This happens when the source and destination vectors have identical
// sizes.
OpFoldResult vector::InsertOp::fold(ArrayRef<Attribute> operands) {
if (position().empty())
return source();
return {};
}
//===----------------------------------------------------------------------===//
// InsertSlicesOp
//===----------------------------------------------------------------------===//

View File

@ -80,13 +80,6 @@ func @extract_vector_type(%arg0: index) {
// -----
func @extract_position_empty(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{expected non-empty position attribute}}
%1 = vector.extract %arg0[] : vector<4x8x16xf32>
}
// -----
func @extract_position_rank_overflow(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{expected position attribute of rank smaller than vector}}
%1 = vector.extract %arg0[0, 0, 0, 0] : vector<4x8x16xf32>
@ -138,13 +131,6 @@ func @insert_element_wrong_type(%arg0: i32, %arg1: vector<4xf32>) {
// -----
func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
// expected-error@+1 {{expected non-empty position attribute}}
%1 = vector.insert %a, %b[] : f32 into vector<4x8x16xf32>
}
// -----
func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
// expected-error@+1 {{expected position attribute of rank smaller than dest vector rank}}
%1 = vector.insert %a, %b[3, 3, 3, 3, 3, 3] : f32 into vector<4x8x16xf32>

View File

@ -158,14 +158,16 @@ func @extract_element(%a: vector<16xf32>) -> f32 {
}
// CHECK-LABEL: @extract
func @extract(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16xf32>, f32) {
func @extract(%arg0: vector<4x8x16xf32>) -> (vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32) {
// CHECK: vector.extract {{.*}}[] : vector<4x8x16xf32>
%0 = vector.extract %arg0[] : vector<4x8x16xf32>
// CHECK: vector.extract {{.*}}[3] : vector<4x8x16xf32>
%1 = vector.extract %arg0[3] : vector<4x8x16xf32>
// CHECK-NEXT: vector.extract {{.*}}[3, 3] : vector<4x8x16xf32>
%2 = vector.extract %arg0[3, 3] : vector<4x8x16xf32>
// CHECK-NEXT: vector.extract {{.*}}[3, 3, 3] : vector<4x8x16xf32>
%3 = vector.extract %arg0[3, 3, 3] : vector<4x8x16xf32>
return %1, %2, %3 : vector<8x16xf32>, vector<16xf32>, f32
return %0, %1, %2, %3 : vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32
}
// CHECK-LABEL: @insert_element
@ -185,7 +187,9 @@ func @insert(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vector<4x8
%2 = vector.insert %b, %res[3, 3] : vector<16xf32> into vector<4x8x16xf32>
// CHECK: vector.insert %{{.*}}, %{{.*}}[3, 3, 3] : f32 into vector<4x8x16xf32>
%3 = vector.insert %a, %res[3, 3, 3] : f32 into vector<4x8x16xf32>
return %3 : vector<4x8x16xf32>
// CHECK: vector.insert %{{.*}}, %{{.*}}[] : vector<4x8x16xf32> into vector<4x8x16xf32>
%4 = vector.insert %3, %3[] : vector<4x8x16xf32> into vector<4x8x16xf32>
return %4 : vector<4x8x16xf32>
}
// CHECK-LABEL: @outerproduct