[mlir][Vector] Thread 0-d vectors through InsertElementOp.

This revision makes concrete use of 0-d vectors to extend the semantics of
InsertElementOp.

Reviewed By: dcaballe, pifon2a

Differential Revision: https://reviews.llvm.org/D114388
This commit is contained in:
Nicolas Vasilache 2021-11-23 12:01:53 +00:00
parent e7026aba00
commit 3ff4e5f2a4
7 changed files with 87 additions and 14 deletions

View File

@ -666,16 +666,18 @@ def Vector_InsertElementOp :
"result", "source",
"$_self.cast<ShapedType>().getElementType()">,
AllTypesMatch<["dest", "result"]>]>,
Arguments<(ins AnyType:$source, AnyVector:$dest,
AnySignlessIntegerOrIndex:$position)>,
Results<(outs AnyVector:$result)> {
Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest,
Optional<AnySignlessIntegerOrIndex>:$position)>,
Results<(outs AnyVectorOfAnyRank:$result)> {
let summary = "insertelement operation";
let description = [{
Takes a scalar source, an 1-D destination vector and a dynamic index
position and inserts the source into the destination at the proper
position. Note that this instruction resembles vector.insert, but
is restricted to 1-D vectors and relaxed to dynamic indices. It is
meant to be closer to LLVM's version:
Takes a scalar source, a 0-D or 1-D destination vector and a dynamic index
position and inserts the source into the destination at the proper position.
Note that this instruction resembles vector.insert, but is restricted to 0-D
and 1-D vectors and relaxed to dynamic indices.
It is meant to be closer to LLVM's version:
https://llvm.org/docs/LangRef.html#insertelement-instruction
Example:
@ -684,14 +686,18 @@ def Vector_InsertElementOp :
%c = arith.constant 15 : i32
%f = arith.constant 0.0f : f32
%1 = vector.insertelement %f, %0[%c : i32]: vector<16xf32>
%2 = vector.insertelement %f, %z[]: vector<f32>
```
}];
let assemblyFormat = [{
$source `,` $dest `[` $position `:` type($position) `]` attr-dict `:`
$source `,` $dest `[` ($position^ `:` type($position))? `]` attr-dict `:`
type($result)
}];
let builders = [
// 0-D builder.
OpBuilder<(ins "Value":$source, "Value":$dest)>,
// 1-D + position builder.
OpBuilder<(ins "Value":$source, "Value":$dest, "Value":$position)>
];
let extraClassDeclaration = [{

View File

@ -663,6 +663,17 @@ public:
if (!llvmType)
return failure();
if (vectorType.getRank() == 0) {
Location loc = insertEltOp.getLoc();
auto idxType = rewriter.getIndexType();
auto zero = rewriter.create<LLVM::ConstantOp>(
loc, typeConverter->convertType(idxType),
rewriter.getIntegerAttr(idxType, 0));
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
insertEltOp, llvmType, adaptor.dest(), adaptor.source(), zero);
return success();
}
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
adaptor.position());

View File

@ -1553,6 +1553,12 @@ static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) {
// InsertElementOp
//===----------------------------------------------------------------------===//
void InsertElementOp::build(OpBuilder &builder, OperationState &result,
Value source, Value dest) {
result.addOperands({source, dest});
result.addTypes(dest.getType());
}
void InsertElementOp::build(OpBuilder &builder, OperationState &result,
Value source, Value dest, Value position) {
result.addOperands({source, dest, position});
@ -1561,8 +1567,15 @@ void InsertElementOp::build(OpBuilder &builder, OperationState &result,
static LogicalResult verify(InsertElementOp op) {
auto dstVectorType = op.getDestVectorType();
if (dstVectorType.getRank() == 0) {
if (op.position())
return op.emitOpError("expected position to be empty with 0-D vector");
return success();
}
if (dstVectorType.getRank() != 1)
return op.emitOpError("expected 1-D vector");
return op.emitOpError("unexpected >1 vector rank");
if (!op.position())
return op.emitOpError("expected position for 1-D vector");
return success();
}

View File

@ -512,6 +512,19 @@ func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 {
// -----
// CHECK-LABEL: @insert_element_0d
// CHECK-SAME: %[[A:.*]]: f32,
func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
// CHECK: %[[B:.*]] = builtin.unrealized_conversion_cast %{{.*}} :
// CHECK: vector<f32> to vector<1xf32>
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
// CHECK: %[[x:.*]] = llvm.insertelement %[[A]], %[[B]][%[[C0]] : {{.*}}] : vector<1xf32>
%1 = vector.insertelement %a, %b[] : vector<f32>
return %1 : vector<f32>
}
// -----
func @insert_element(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> {
%0 = arith.constant 3 : i32
%1 = vector.insertelement %arg0, %arg1[%0 : i32] : vector<4xf32>

View File

@ -79,7 +79,7 @@ func @extract_element(%arg0: vector<f32>) {
}
// -----
func @extract_element(%arg0: vector<4xf32>) {
%c = arith.constant 3 : i32
// expected-error@+1 {{expected position for 1-D vector}}
@ -138,9 +138,25 @@ func @extract_position_overflow(%arg0: vector<4x8x16xf32>) {
// -----
func @insert_element(%arg0: f32, %arg1: vector<f32>) {
%c = arith.constant 3 : i32
// expected-error@+1 {{expected position to be empty with 0-D vector}}
%0 = vector.insertelement %arg0, %arg1[%c : i32] : vector<f32>
}
// -----
func @insert_element(%arg0: f32, %arg1: vector<4xf32>) {
%c = arith.constant 3 : i32
// expected-error@+1 {{expected position for 1-D vector}}
%0 = vector.insertelement %arg0, %arg1[] : vector<4xf32>
}
// -----
func @insert_element(%arg0: f32, %arg1: vector<4x4xf32>) {
%c = arith.constant 3 : i32
// expected-error@+1 {{'vector.insertelement' op expected 1-D vector}}
// expected-error@+1 {{unexpected >1 vector rank}}
%0 = vector.insertelement %arg0, %arg1[%c : i32] : vector<4x4xf32>
}

View File

@ -192,6 +192,13 @@ func @extract(%arg0: vector<4x8x16xf32>) -> (vector<4x8x16xf32>, vector<8x16xf32
return %0, %1, %2, %3 : vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32
}
// CHECK-LABEL: @insert_element_0d
func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
// CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[] : vector<f32>
%1 = vector.insertelement %a, %b[] : vector<f32>
return %1 : vector<f32>
}
// CHECK-LABEL: @insert_element
func @insert_element(%a: f32, %b: vector<16xf32>) -> vector<16xf32> {
// CHECK: %[[C15:.*]] = arith.constant 15 : i32

View File

@ -10,8 +10,15 @@ func @extract_element_0d(%a: vector<f32>) {
return
}
func @insert_element_0d(%a: f32, %b: vector<f32>) -> (vector<f32>) {
%1 = vector.insertelement %a, %b[] : vector<f32>
return %1: vector<f32>
}
func @entry() {
%1 = arith.constant dense<42.0> : vector<f32>
call @extract_element_0d(%1) : (vector<f32>) -> ()
%0 = arith.constant 42.0 : f32
%1 = arith.constant dense<0.0> : vector<f32>
%2 = call @insert_element_0d(%0, %1) : (f32, vector<f32>) -> (vector<f32>)
call @extract_element_0d(%2) : (vector<f32>) -> ()
return
}