forked from OSchip/llvm-project
[mlir] Allow empty position in vector.insert and vector.extract
Such ops are no-ops and are folded to their respective `source`/`vector` operand. Differential Revision: https://reviews.llvm.org/D101879
This commit is contained in:
parent
c52cbe63e4
commit
864adf399e
|
@ -764,7 +764,9 @@ def Vector_InsertOp :
|
|||
return dest().getType().cast<VectorType>();
|
||||
}
|
||||
}];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Vector_InsertSlicesOp :
|
||||
|
|
|
@ -656,6 +656,12 @@ public:
|
|||
if (!llvmResultType)
|
||||
return failure();
|
||||
|
||||
// Extract entire vector. Should be handled by folder, but just to be safe.
|
||||
if (positionArrayAttr.empty()) {
|
||||
rewriter.replaceOp(extractOp, adaptor.vector());
|
||||
return success();
|
||||
}
|
||||
|
||||
// One-shot extraction of vector from array (only requires extractvalue).
|
||||
if (resultType.isa<VectorType>()) {
|
||||
Value extracted = rewriter.create<LLVM::ExtractValueOp>(
|
||||
|
@ -762,6 +768,13 @@ public:
|
|||
if (!llvmResultType)
|
||||
return failure();
|
||||
|
||||
// Overwrite entire vector with value. Should be handled by folder, but
|
||||
// just to be safe.
|
||||
if (positionArrayAttr.empty()) {
|
||||
rewriter.replaceOp(insertOp, adaptor.source());
|
||||
return success();
|
||||
}
|
||||
|
||||
// One-shot insertion of a vector into an array (only requires insertvalue).
|
||||
if (sourceType.isa<VectorType>()) {
|
||||
Value inserted = rewriter.create<LLVM::InsertValueOp>(
|
||||
|
|
|
@ -872,8 +872,6 @@ static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) {
|
|||
|
||||
static LogicalResult verify(vector::ExtractOp op) {
|
||||
auto positionAttr = op.position().getValue();
|
||||
if (positionAttr.empty())
|
||||
return op.emitOpError("expected non-empty position attribute");
|
||||
if (positionAttr.size() > static_cast<unsigned>(op.getVectorType().getRank()))
|
||||
return op.emitOpError(
|
||||
"expected position attribute of rank smaller than vector rank");
|
||||
|
@ -1151,6 +1149,8 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
|
|||
}
|
||||
|
||||
OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
|
||||
if (position().empty())
|
||||
return vector();
|
||||
if (succeeded(foldExtractOpFromExtractChain(*this)))
|
||||
return getResult();
|
||||
if (succeeded(foldExtractOpFromTranspose(*this)))
|
||||
|
@ -1557,8 +1557,6 @@ void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
|
|||
|
||||
static LogicalResult verify(InsertOp op) {
|
||||
auto positionAttr = op.position().getValue();
|
||||
if (positionAttr.empty())
|
||||
return op.emitOpError("expected non-empty position attribute");
|
||||
auto destVectorType = op.getDestVectorType();
|
||||
if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank()))
|
||||
return op.emitOpError(
|
||||
|
@ -1612,6 +1610,15 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|||
results.add<InsertToShapeCast>(context);
|
||||
}
|
||||
|
||||
// Eliminates insert operations that produce values identical to their source
|
||||
// value. This happens when the source and destination vectors have identical
|
||||
// sizes.
|
||||
OpFoldResult vector::InsertOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (position().empty())
|
||||
return source();
|
||||
return {};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// InsertSlicesOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -80,13 +80,6 @@ func @extract_vector_type(%arg0: index) {
|
|||
|
||||
// -----
|
||||
|
||||
func @extract_position_empty(%arg0: vector<4x8x16xf32>) {
|
||||
// expected-error@+1 {{expected non-empty position attribute}}
|
||||
%1 = vector.extract %arg0[] : vector<4x8x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @extract_position_rank_overflow(%arg0: vector<4x8x16xf32>) {
|
||||
// expected-error@+1 {{expected position attribute of rank smaller than vector}}
|
||||
%1 = vector.extract %arg0[0, 0, 0, 0] : vector<4x8x16xf32>
|
||||
|
@ -138,13 +131,6 @@ func @insert_element_wrong_type(%arg0: i32, %arg1: vector<4xf32>) {
|
|||
|
||||
// -----
|
||||
|
||||
func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
|
||||
// expected-error@+1 {{expected non-empty position attribute}}
|
||||
%1 = vector.insert %a, %b[] : f32 into vector<4x8x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
|
||||
// expected-error@+1 {{expected position attribute of rank smaller than dest vector rank}}
|
||||
%1 = vector.insert %a, %b[3, 3, 3, 3, 3, 3] : f32 into vector<4x8x16xf32>
|
||||
|
|
|
@ -158,14 +158,16 @@ func @extract_element(%a: vector<16xf32>) -> f32 {
|
|||
}
|
||||
|
||||
// CHECK-LABEL: @extract
|
||||
func @extract(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16xf32>, f32) {
|
||||
func @extract(%arg0: vector<4x8x16xf32>) -> (vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32) {
|
||||
// CHECK: vector.extract {{.*}}[] : vector<4x8x16xf32>
|
||||
%0 = vector.extract %arg0[] : vector<4x8x16xf32>
|
||||
// CHECK: vector.extract {{.*}}[3] : vector<4x8x16xf32>
|
||||
%1 = vector.extract %arg0[3] : vector<4x8x16xf32>
|
||||
// CHECK-NEXT: vector.extract {{.*}}[3, 3] : vector<4x8x16xf32>
|
||||
%2 = vector.extract %arg0[3, 3] : vector<4x8x16xf32>
|
||||
// CHECK-NEXT: vector.extract {{.*}}[3, 3, 3] : vector<4x8x16xf32>
|
||||
%3 = vector.extract %arg0[3, 3, 3] : vector<4x8x16xf32>
|
||||
return %1, %2, %3 : vector<8x16xf32>, vector<16xf32>, f32
|
||||
return %0, %1, %2, %3 : vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @insert_element
|
||||
|
@ -185,7 +187,9 @@ func @insert(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vector<4x8
|
|||
%2 = vector.insert %b, %res[3, 3] : vector<16xf32> into vector<4x8x16xf32>
|
||||
// CHECK: vector.insert %{{.*}}, %{{.*}}[3, 3, 3] : f32 into vector<4x8x16xf32>
|
||||
%3 = vector.insert %a, %res[3, 3, 3] : f32 into vector<4x8x16xf32>
|
||||
return %3 : vector<4x8x16xf32>
|
||||
// CHECK: vector.insert %{{.*}}, %{{.*}}[] : vector<4x8x16xf32> into vector<4x8x16xf32>
|
||||
%4 = vector.insert %3, %3[] : vector<4x8x16xf32> into vector<4x8x16xf32>
|
||||
return %4 : vector<4x8x16xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @outerproduct
|
||||
|
|
Loading…
Reference in New Issue