forked from OSchip/llvm-project
[mlir][Vector] Thread 0-d vectors through InsertElementOp.
This revision makes concrete use of 0-d vectors to extend the semantics of InsertElementOp. Reviewed By: dcaballe, pifon2a Differential Revision: https://reviews.llvm.org/D114388
This commit is contained in:
parent
e7026aba00
commit
3ff4e5f2a4
|
@ -666,16 +666,18 @@ def Vector_InsertElementOp :
|
|||
"result", "source",
|
||||
"$_self.cast<ShapedType>().getElementType()">,
|
||||
AllTypesMatch<["dest", "result"]>]>,
|
||||
Arguments<(ins AnyType:$source, AnyVector:$dest,
|
||||
AnySignlessIntegerOrIndex:$position)>,
|
||||
Results<(outs AnyVector:$result)> {
|
||||
Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest,
|
||||
Optional<AnySignlessIntegerOrIndex>:$position)>,
|
||||
Results<(outs AnyVectorOfAnyRank:$result)> {
|
||||
let summary = "insertelement operation";
|
||||
let description = [{
|
||||
Takes a scalar source, an 1-D destination vector and a dynamic index
|
||||
position and inserts the source into the destination at the proper
|
||||
position. Note that this instruction resembles vector.insert, but
|
||||
is restricted to 1-D vectors and relaxed to dynamic indices. It is
|
||||
meant to be closer to LLVM's version:
|
||||
Takes a scalar source, a 0-D or 1-D destination vector and a dynamic index
|
||||
position and inserts the source into the destination at the proper position.
|
||||
|
||||
Note that this instruction resembles vector.insert, but is restricted to 0-D
|
||||
and 1-D vectors and relaxed to dynamic indices.
|
||||
|
||||
It is meant to be closer to LLVM's version:
|
||||
https://llvm.org/docs/LangRef.html#insertelement-instruction
|
||||
|
||||
Example:
|
||||
|
@ -684,14 +686,18 @@ def Vector_InsertElementOp :
|
|||
%c = arith.constant 15 : i32
|
||||
%f = arith.constant 0.0f : f32
|
||||
%1 = vector.insertelement %f, %0[%c : i32]: vector<16xf32>
|
||||
%2 = vector.insertelement %f, %z[]: vector<f32>
|
||||
```
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
$source `,` $dest `[` $position `:` type($position) `]` attr-dict `:`
|
||||
$source `,` $dest `[` ($position^ `:` type($position))? `]` attr-dict `:`
|
||||
type($result)
|
||||
}];
|
||||
|
||||
let builders = [
|
||||
// 0-D builder.
|
||||
OpBuilder<(ins "Value":$source, "Value":$dest)>,
|
||||
// 1-D + position builder.
|
||||
OpBuilder<(ins "Value":$source, "Value":$dest, "Value":$position)>
|
||||
];
|
||||
let extraClassDeclaration = [{
|
||||
|
|
|
@ -663,6 +663,17 @@ public:
|
|||
if (!llvmType)
|
||||
return failure();
|
||||
|
||||
if (vectorType.getRank() == 0) {
|
||||
Location loc = insertEltOp.getLoc();
|
||||
auto idxType = rewriter.getIndexType();
|
||||
auto zero = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, typeConverter->convertType(idxType),
|
||||
rewriter.getIntegerAttr(idxType, 0));
|
||||
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
|
||||
insertEltOp, llvmType, adaptor.dest(), adaptor.source(), zero);
|
||||
return success();
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
|
||||
insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
|
||||
adaptor.position());
|
||||
|
|
|
@ -1553,6 +1553,12 @@ static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) {
|
|||
// InsertElementOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void InsertElementOp::build(OpBuilder &builder, OperationState &result,
|
||||
Value source, Value dest) {
|
||||
result.addOperands({source, dest});
|
||||
result.addTypes(dest.getType());
|
||||
}
|
||||
|
||||
void InsertElementOp::build(OpBuilder &builder, OperationState &result,
|
||||
Value source, Value dest, Value position) {
|
||||
result.addOperands({source, dest, position});
|
||||
|
@ -1561,8 +1567,15 @@ void InsertElementOp::build(OpBuilder &builder, OperationState &result,
|
|||
|
||||
static LogicalResult verify(InsertElementOp op) {
|
||||
auto dstVectorType = op.getDestVectorType();
|
||||
if (dstVectorType.getRank() == 0) {
|
||||
if (op.position())
|
||||
return op.emitOpError("expected position to be empty with 0-D vector");
|
||||
return success();
|
||||
}
|
||||
if (dstVectorType.getRank() != 1)
|
||||
return op.emitOpError("expected 1-D vector");
|
||||
return op.emitOpError("unexpected >1 vector rank");
|
||||
if (!op.position())
|
||||
return op.emitOpError("expected position for 1-D vector");
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -512,6 +512,19 @@ func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @insert_element_0d
|
||||
// CHECK-SAME: %[[A:.*]]: f32,
|
||||
func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
|
||||
// CHECK: %[[B:.*]] = builtin.unrealized_conversion_cast %{{.*}} :
|
||||
// CHECK: vector<f32> to vector<1xf32>
|
||||
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
|
||||
// CHECK: %[[x:.*]] = llvm.insertelement %[[A]], %[[B]][%[[C0]] : {{.*}}] : vector<1xf32>
|
||||
%1 = vector.insertelement %a, %b[] : vector<f32>
|
||||
return %1 : vector<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @insert_element(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> {
|
||||
%0 = arith.constant 3 : i32
|
||||
%1 = vector.insertelement %arg0, %arg1[%0 : i32] : vector<4xf32>
|
||||
|
|
|
@ -79,7 +79,7 @@ func @extract_element(%arg0: vector<f32>) {
|
|||
}
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
func @extract_element(%arg0: vector<4xf32>) {
|
||||
%c = arith.constant 3 : i32
|
||||
// expected-error@+1 {{expected position for 1-D vector}}
|
||||
|
@ -138,9 +138,25 @@ func @extract_position_overflow(%arg0: vector<4x8x16xf32>) {
|
|||
|
||||
// -----
|
||||
|
||||
func @insert_element(%arg0: f32, %arg1: vector<f32>) {
|
||||
%c = arith.constant 3 : i32
|
||||
// expected-error@+1 {{expected position to be empty with 0-D vector}}
|
||||
%0 = vector.insertelement %arg0, %arg1[%c : i32] : vector<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @insert_element(%arg0: f32, %arg1: vector<4xf32>) {
|
||||
%c = arith.constant 3 : i32
|
||||
// expected-error@+1 {{expected position for 1-D vector}}
|
||||
%0 = vector.insertelement %arg0, %arg1[] : vector<4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @insert_element(%arg0: f32, %arg1: vector<4x4xf32>) {
|
||||
%c = arith.constant 3 : i32
|
||||
// expected-error@+1 {{'vector.insertelement' op expected 1-D vector}}
|
||||
// expected-error@+1 {{unexpected >1 vector rank}}
|
||||
%0 = vector.insertelement %arg0, %arg1[%c : i32] : vector<4x4xf32>
|
||||
}
|
||||
|
||||
|
|
|
@ -192,6 +192,13 @@ func @extract(%arg0: vector<4x8x16xf32>) -> (vector<4x8x16xf32>, vector<8x16xf32
|
|||
return %0, %1, %2, %3 : vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @insert_element_0d
|
||||
func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
|
||||
// CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[] : vector<f32>
|
||||
%1 = vector.insertelement %a, %b[] : vector<f32>
|
||||
return %1 : vector<f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @insert_element
|
||||
func @insert_element(%a: f32, %b: vector<16xf32>) -> vector<16xf32> {
|
||||
// CHECK: %[[C15:.*]] = arith.constant 15 : i32
|
||||
|
|
|
@ -10,8 +10,15 @@ func @extract_element_0d(%a: vector<f32>) {
|
|||
return
|
||||
}
|
||||
|
||||
func @insert_element_0d(%a: f32, %b: vector<f32>) -> (vector<f32>) {
|
||||
%1 = vector.insertelement %a, %b[] : vector<f32>
|
||||
return %1: vector<f32>
|
||||
}
|
||||
|
||||
func @entry() {
|
||||
%1 = arith.constant dense<42.0> : vector<f32>
|
||||
call @extract_element_0d(%1) : (vector<f32>) -> ()
|
||||
%0 = arith.constant 42.0 : f32
|
||||
%1 = arith.constant dense<0.0> : vector<f32>
|
||||
%2 = call @insert_element_0d(%0, %1) : (f32, vector<f32>) -> (vector<f32>)
|
||||
call @extract_element_0d(%2) : (vector<f32>) -> ()
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue