From 864adf399e58a6bfd823136fc2cbcfe9dff5b4a8 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Thu, 13 May 2021 12:53:15 +0900 Subject: [PATCH] [mlir] Allow empty position in vector.insert and vector.extract Such ops are no-ops and are folded to their respective `source`/`vector` operand. Differential Revision: https://reviews.llvm.org/D101879 --- mlir/include/mlir/Dialect/Vector/VectorOps.td | 2 ++ .../VectorToLLVM/ConvertVectorToLLVM.cpp | 13 +++++++++++++ mlir/lib/Dialect/Vector/VectorOps.cpp | 15 +++++++++++---- mlir/test/Dialect/Vector/invalid.mlir | 14 -------------- mlir/test/Dialect/Vector/ops.mlir | 10 +++++++--- 5 files changed, 33 insertions(+), 21 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td index 45c0ccaa0928..6c621f93c024 100644 --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -764,7 +764,9 @@ def Vector_InsertOp : return dest().getType().cast(); } }]; + let hasCanonicalizer = 1; + let hasFolder = 1; } def Vector_InsertSlicesOp : diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 9ecee857e2e5..9db34d7411e9 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -656,6 +656,12 @@ public: if (!llvmResultType) return failure(); + // Extract entire vector. Should be handled by folder, but just to be safe. + if (positionArrayAttr.empty()) { + rewriter.replaceOp(extractOp, adaptor.vector()); + return success(); + } + // One-shot extraction of vector from array (only requires extractvalue). if (resultType.isa()) { Value extracted = rewriter.create( @@ -762,6 +768,13 @@ public: if (!llvmResultType) return failure(); + // Overwrite entire vector with value. Should be handled by folder, but + // just to be safe. + if (positionArrayAttr.empty()) { + rewriter.replaceOp(insertOp, adaptor.source()); + return success(); + } + // One-shot insertion of a vector into an array (only requires insertvalue). if (sourceType.isa()) { Value inserted = rewriter.create( diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp index c86817cdc3ab..f24b9171203a 100644 --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -872,8 +872,6 @@ static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) { static LogicalResult verify(vector::ExtractOp op) { auto positionAttr = op.position().getValue(); - if (positionAttr.empty()) - return op.emitOpError("expected non-empty position attribute"); if (positionAttr.size() > static_cast(op.getVectorType().getRank())) return op.emitOpError( "expected position attribute of rank smaller than vector rank"); @@ -1151,6 +1149,8 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) { } OpFoldResult ExtractOp::fold(ArrayRef) { + if (position().empty()) + return vector(); if (succeeded(foldExtractOpFromExtractChain(*this))) return getResult(); if (succeeded(foldExtractOpFromTranspose(*this))) @@ -1557,8 +1557,6 @@ void InsertOp::build(OpBuilder &builder, OperationState &result, Value source, static LogicalResult verify(InsertOp op) { auto positionAttr = op.position().getValue(); - if (positionAttr.empty()) - return op.emitOpError("expected non-empty position attribute"); auto destVectorType = op.getDestVectorType(); if (positionAttr.size() > static_cast(destVectorType.getRank())) return op.emitOpError( @@ -1612,6 +1610,15 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results, results.add(context); } +// Eliminates insert operations that produce values identical to their source +// value. This happens when the source and destination vectors have identical +// sizes. +OpFoldResult vector::InsertOp::fold(ArrayRef operands) { + if (position().empty()) + return source(); + return {}; +} + //===----------------------------------------------------------------------===// // InsertSlicesOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 545a3ac8c463..e06380df3f66 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -80,13 +80,6 @@ func @extract_vector_type(%arg0: index) { // ----- -func @extract_position_empty(%arg0: vector<4x8x16xf32>) { - // expected-error@+1 {{expected non-empty position attribute}} - %1 = vector.extract %arg0[] : vector<4x8x16xf32> -} - -// ----- - func @extract_position_rank_overflow(%arg0: vector<4x8x16xf32>) { // expected-error@+1 {{expected position attribute of rank smaller than vector}} %1 = vector.extract %arg0[0, 0, 0, 0] : vector<4x8x16xf32> @@ -138,13 +131,6 @@ func @insert_element_wrong_type(%arg0: i32, %arg1: vector<4xf32>) { // ----- -func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) { - // expected-error@+1 {{expected non-empty position attribute}} - %1 = vector.insert %a, %b[] : f32 into vector<4x8x16xf32> -} - -// ----- - func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) { // expected-error@+1 {{expected position attribute of rank smaller than dest vector rank}} %1 = vector.insert %a, %b[3, 3, 3, 3, 3, 3] : f32 into vector<4x8x16xf32> diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index c3bb8fffbb1a..8beff28ef8a0 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -158,14 +158,16 @@ func @extract_element(%a: vector<16xf32>) -> f32 { } // CHECK-LABEL: @extract -func @extract(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16xf32>, f32) { +func @extract(%arg0: vector<4x8x16xf32>) -> (vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32) { + // CHECK: vector.extract {{.*}}[] : vector<4x8x16xf32> + %0 = vector.extract %arg0[] : vector<4x8x16xf32> // CHECK: vector.extract {{.*}}[3] : vector<4x8x16xf32> %1 = vector.extract %arg0[3] : vector<4x8x16xf32> // CHECK-NEXT: vector.extract {{.*}}[3, 3] : vector<4x8x16xf32> %2 = vector.extract %arg0[3, 3] : vector<4x8x16xf32> // CHECK-NEXT: vector.extract {{.*}}[3, 3, 3] : vector<4x8x16xf32> %3 = vector.extract %arg0[3, 3, 3] : vector<4x8x16xf32> - return %1, %2, %3 : vector<8x16xf32>, vector<16xf32>, f32 + return %0, %1, %2, %3 : vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32 } // CHECK-LABEL: @insert_element @@ -185,7 +187,9 @@ func @insert(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vector<4x8 %2 = vector.insert %b, %res[3, 3] : vector<16xf32> into vector<4x8x16xf32> // CHECK: vector.insert %{{.*}}, %{{.*}}[3, 3, 3] : f32 into vector<4x8x16xf32> %3 = vector.insert %a, %res[3, 3, 3] : f32 into vector<4x8x16xf32> - return %3 : vector<4x8x16xf32> + // CHECK: vector.insert %{{.*}}, %{{.*}}[] : vector<4x8x16xf32> into vector<4x8x16xf32> + %4 = vector.insert %3, %3[] : vector<4x8x16xf32> into vector<4x8x16xf32> + return %4 : vector<4x8x16xf32> } // CHECK-LABEL: @outerproduct