forked from OSchip/llvm-project
[VecOps] Rename vector.[insert|extract]element to just vector.[insert|extract]
Since these operations lower to [insert|extract][element|value] at LLVM dialect level, neither element nor value would correctly reflect the meaning. PiperOrigin-RevId: 284240727
This commit is contained in:
parent
be3ed14658
commit
d37f27251f
|
@ -216,21 +216,21 @@ def Vector_BroadcastOp :
|
|||
}];
|
||||
}
|
||||
|
||||
def Vector_ExtractElementOp :
|
||||
Vector_Op<"extractelement", [NoSideEffect,
|
||||
def Vector_ExtractOp :
|
||||
Vector_Op<"extract", [NoSideEffect,
|
||||
PredOpTrait<"operand and result have same element type",
|
||||
TCresVTEtIsSameAsOpBase<0, 0>>]>,
|
||||
Arguments<(ins AnyVector:$vector, I32ArrayAttr:$position)>,
|
||||
Results<(outs AnyType)> {
|
||||
let summary = "extractelement operation";
|
||||
let summary = "extract operation";
|
||||
let description = [{
|
||||
Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at
|
||||
the proper position. Degenerates to an element type in the 0-D case.
|
||||
|
||||
Examples:
|
||||
```
|
||||
%1 = vector.extractelement %0[3]: vector<4x8x16xf32>
|
||||
%2 = vector.extractelement %0[3, 3, 3]: vector<4x8x16xf32>
|
||||
%1 = vector.extract %0[3]: vector<4x8x16xf32>
|
||||
%2 = vector.extract %0[3, 3, 3]: vector<4x8x16xf32>
|
||||
```
|
||||
}];
|
||||
let builders = [OpBuilder<
|
||||
|
@ -243,15 +243,15 @@ def Vector_ExtractElementOp :
|
|||
}];
|
||||
}
|
||||
|
||||
def Vector_InsertElementOp :
|
||||
Vector_Op<"insertelement", [NoSideEffect,
|
||||
def Vector_InsertOp :
|
||||
Vector_Op<"insert", [NoSideEffect,
|
||||
PredOpTrait<"source operand and result have same element type",
|
||||
TCresVTEtIsSameAsOpBase<0, 0>>,
|
||||
PredOpTrait<"dest operand and result have same type",
|
||||
TCresIsSameAsOpBase<0, 1>>]>,
|
||||
Arguments<(ins AnyType:$source, AnyVector:$dest, I32ArrayAttr:$position)>,
|
||||
Results<(outs AnyVector)> {
|
||||
let summary = "insertelement operation";
|
||||
let summary = "insert operation";
|
||||
let description = [{
|
||||
Takes an n-D source vector, an (n+k)-D destination vector and a k-D position
|
||||
and inserts the n-D source into the (n+k)-D destination at the proper
|
||||
|
@ -259,9 +259,9 @@ def Vector_InsertElementOp :
|
|||
|
||||
Examples:
|
||||
```
|
||||
%2 = vector.insertelement %0, %1[3 : i32]:
|
||||
%2 = vector.insert %0, %1[3 : i32]:
|
||||
vector<8x16xf32> into vector<4x8x16xf32>
|
||||
%5 = vector.insertelement %3, %4[3 : i32, 3 : i32, 3 : i32]:
|
||||
%5 = vector.insert %3, %4[3 : i32, 3 : i32, 3 : i32]:
|
||||
f32 into vector<4x8x16xf32>
|
||||
```
|
||||
}];
|
||||
|
|
|
@ -238,15 +238,15 @@ class VectorExtractElementOpConversion : public LLVMOpLowering {
|
|||
public:
|
||||
explicit VectorExtractElementOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: LLVMOpLowering(vector::ExtractElementOp::getOperationName(), context,
|
||||
: LLVMOpLowering(vector::ExtractOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
auto adaptor = vector::ExtractElementOpOperandAdaptor(operands);
|
||||
auto extractOp = cast<vector::ExtractElementOp>(op);
|
||||
auto adaptor = vector::ExtractOpOperandAdaptor(operands);
|
||||
auto extractOp = cast<vector::ExtractOp>(op);
|
||||
auto vectorType = extractOp.vector()->getType().cast<VectorType>();
|
||||
auto resultType = extractOp.getResult()->getType();
|
||||
auto llvmResultType = lowering.convertType(resultType);
|
||||
|
|
|
@ -324,35 +324,33 @@ SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ExtractElementOp
|
||||
// ExtractOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static Type inferExtractElementOpResultType(VectorType vectorType,
|
||||
ArrayAttr position) {
|
||||
static Type inferExtractOpResultType(VectorType vectorType,
|
||||
ArrayAttr position) {
|
||||
if (static_cast<int64_t>(position.size()) == vectorType.getRank())
|
||||
return vectorType.getElementType();
|
||||
return VectorType::get(vectorType.getShape().drop_front(position.size()),
|
||||
vectorType.getElementType());
|
||||
}
|
||||
|
||||
void vector::ExtractElementOp::build(Builder *builder, OperationState &result,
|
||||
Value *source,
|
||||
ArrayRef<int32_t> position) {
|
||||
void vector::ExtractOp::build(Builder *builder, OperationState &result,
|
||||
Value *source, ArrayRef<int32_t> position) {
|
||||
result.addOperands(source);
|
||||
auto positionAttr = builder->getI32ArrayAttr(position);
|
||||
result.addTypes(inferExtractElementOpResultType(
|
||||
source->getType().cast<VectorType>(), positionAttr));
|
||||
result.addTypes(inferExtractOpResultType(source->getType().cast<VectorType>(),
|
||||
positionAttr));
|
||||
result.addAttribute(getPositionAttrName(), positionAttr);
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, vector::ExtractElementOp op) {
|
||||
static void print(OpAsmPrinter &p, vector::ExtractOp op) {
|
||||
p << op.getOperationName() << " " << *op.vector() << op.position();
|
||||
p.printOptionalAttrDict(op.getAttrs(), {"position"});
|
||||
p << " : " << op.vector()->getType();
|
||||
}
|
||||
|
||||
static ParseResult parseExtractElementOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) {
|
||||
llvm::SMLoc attributeLoc, typeLoc;
|
||||
SmallVector<NamedAttribute, 4> attrs;
|
||||
OpAsmParser::OperandType vector;
|
||||
|
@ -375,13 +373,13 @@ static ParseResult parseExtractElementOp(OpAsmParser &parser,
|
|||
attributeLoc,
|
||||
"expected position attribute of rank smaller than vector rank");
|
||||
|
||||
Type resType = inferExtractElementOpResultType(vectorType, positionAttr);
|
||||
Type resType = inferExtractOpResultType(vectorType, positionAttr);
|
||||
result.attributes = attrs;
|
||||
return failure(parser.resolveOperand(vector, type, result.operands) ||
|
||||
parser.addTypeToList(resType, result.types));
|
||||
}
|
||||
|
||||
static LogicalResult verify(vector::ExtractElementOp op) {
|
||||
static LogicalResult verify(vector::ExtractOp op) {
|
||||
auto positionAttr = op.position().getValue();
|
||||
if (positionAttr.empty())
|
||||
return op.emitOpError("expected non-empty position attribute");
|
||||
|
@ -447,29 +445,26 @@ static ParseResult parseBroadcastOp(OpAsmParser &parser,
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// InsertElementOp
|
||||
// InsertOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void InsertElementOp::build(Builder *builder, OperationState &result,
|
||||
Value *source, Value *dest,
|
||||
ArrayRef<int32_t> position) {
|
||||
void InsertOp::build(Builder *builder, OperationState &result, Value *source,
|
||||
Value *dest, ArrayRef<int32_t> position) {
|
||||
result.addOperands({source, dest});
|
||||
auto positionAttr = builder->getI32ArrayAttr(position);
|
||||
result.addTypes(dest->getType());
|
||||
result.addAttribute(getPositionAttrName(), positionAttr);
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, InsertElementOp op) {
|
||||
static void print(OpAsmPrinter &p, InsertOp op) {
|
||||
p << op.getOperationName() << " " << *op.source() << ", " << *op.dest()
|
||||
<< op.position();
|
||||
p.printOptionalAttrDict(op.getAttrs(),
|
||||
{InsertElementOp::getPositionAttrName()});
|
||||
p.printOptionalAttrDict(op.getAttrs(), {InsertOp::getPositionAttrName()});
|
||||
p << " : " << op.getSourceType();
|
||||
p << " into " << op.getDestVectorType();
|
||||
}
|
||||
|
||||
static ParseResult parseInsertElementOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
static ParseResult parseInsertOp(OpAsmParser &parser, OperationState &result) {
|
||||
SmallVector<NamedAttribute, 4> attrs;
|
||||
OpAsmParser::OperandType source, dest;
|
||||
Type sourceType;
|
||||
|
@ -477,8 +472,7 @@ static ParseResult parseInsertElementOp(OpAsmParser &parser,
|
|||
Attribute attr;
|
||||
return failure(parser.parseOperand(source) || parser.parseComma() ||
|
||||
parser.parseOperand(dest) ||
|
||||
parser.parseAttribute(attr,
|
||||
InsertElementOp::getPositionAttrName(),
|
||||
parser.parseAttribute(attr, InsertOp::getPositionAttrName(),
|
||||
result.attributes) ||
|
||||
parser.parseOptionalAttrDict(attrs) ||
|
||||
parser.parseColonType(sourceType) ||
|
||||
|
@ -488,7 +482,7 @@ static ParseResult parseInsertElementOp(OpAsmParser &parser,
|
|||
parser.addTypeToList(destType, result.types));
|
||||
}
|
||||
|
||||
static LogicalResult verify(InsertElementOp op) {
|
||||
static LogicalResult verify(InsertOp op) {
|
||||
auto positionAttr = op.position().getValue();
|
||||
if (positionAttr.empty())
|
||||
return op.emitOpError("expected non-empty position attribute");
|
||||
|
|
|
@ -231,7 +231,7 @@ func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: vector
|
|||
// CHECK: llvm.return {{.*}} : !llvm<"[2 x <3 x float>]">
|
||||
|
||||
func @extract_vec_2d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<3x16xf32> {
|
||||
%0 = vector.extractelement %arg0[0 : i32]: vector<4x3x16xf32>
|
||||
%0 = vector.extract %arg0[0 : i32]: vector<4x3x16xf32>
|
||||
return %0 : vector<3x16xf32>
|
||||
}
|
||||
// CHECK-LABEL: extract_vec_2d_from_vec_3d
|
||||
|
@ -239,7 +239,7 @@ func @extract_vec_2d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<3x16xf32>
|
|||
// CHECK: llvm.return %{{.*}} : !llvm<"[3 x <16 x float>]">
|
||||
|
||||
func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 {
|
||||
%0 = vector.extractelement %arg0[0 : i32, 0 : i32, 0 : i32]: vector<4x3x16xf32>
|
||||
%0 = vector.extract %arg0[0 : i32, 0 : i32, 0 : i32]: vector<4x3x16xf32>
|
||||
return %0 : f32
|
||||
}
|
||||
// CHECK-LABEL: extract_element_from_vec_3d
|
||||
|
|
|
@ -31,79 +31,79 @@ func @broadcast_dim2_mismatch(%arg0: vector<4x8xf32>) {
|
|||
|
||||
// -----
|
||||
|
||||
func @extract_element_vector_type(%arg0: index) {
|
||||
func @extract_vector_type(%arg0: index) {
|
||||
// expected-error@+1 {{expected vector type}}
|
||||
%1 = vector.extractelement %arg0[] : index
|
||||
%1 = vector.extract %arg0[] : index
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @extractelement_position_empty(%arg0: vector<4x8x16xf32>) {
|
||||
func @extract_position_empty(%arg0: vector<4x8x16xf32>) {
|
||||
// expected-error@+1 {{expected non-empty position attribute}}
|
||||
%1 = vector.extractelement %arg0[] : vector<4x8x16xf32>
|
||||
%1 = vector.extract %arg0[] : vector<4x8x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @extractelement_position_rank_overflow(%arg0: vector<4x8x16xf32>) {
|
||||
func @extract_position_rank_overflow(%arg0: vector<4x8x16xf32>) {
|
||||
// expected-error@+1 {{expected position attribute of rank smaller than vector}}
|
||||
%1 = vector.extractelement %arg0[0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<4x8x16xf32>
|
||||
%1 = vector.extract %arg0[0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<4x8x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @extractelement_position_rank_overflow_generic(%arg0: vector<4x8x16xf32>) {
|
||||
func @extract_position_rank_overflow_generic(%arg0: vector<4x8x16xf32>) {
|
||||
// expected-error@+1 {{expected position attribute of rank smaller than vector}}
|
||||
%1 = "vector.extractelement" (%arg0) { position = [0 : i32, 0 : i32, 0 : i32, 0 : i32] } : (vector<4x8x16xf32>) -> (vector<16xf32>)
|
||||
%1 = "vector.extract" (%arg0) { position = [0 : i32, 0 : i32, 0 : i32, 0 : i32] } : (vector<4x8x16xf32>) -> (vector<16xf32>)
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @extractelement_position_overflow(%arg0: vector<4x8x16xf32>) {
|
||||
func @extract_position_overflow(%arg0: vector<4x8x16xf32>) {
|
||||
// expected-error@+1 {{expected position attribute #2 to be a non-negative integer smaller than the corresponding vector dimension}}
|
||||
%1 = vector.extractelement %arg0[0 : i32, 43 : i32, 0 : i32] : vector<4x8x16xf32>
|
||||
%1 = vector.extract %arg0[0 : i32, 43 : i32, 0 : i32] : vector<4x8x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @extractelement_position_overflow(%arg0: vector<4x8x16xf32>) {
|
||||
func @extract_position_overflow(%arg0: vector<4x8x16xf32>) {
|
||||
// expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension}}
|
||||
%1 = vector.extractelement %arg0[0 : i32, 0 : i32, -1 : i32] : vector<4x8x16xf32>
|
||||
%1 = vector.extract %arg0[0 : i32, 0 : i32, -1 : i32] : vector<4x8x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @insert_element_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
|
||||
func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
|
||||
// expected-error@+1 {{expected non-empty position attribute}}
|
||||
%1 = vector.insertelement %a, %b[] : f32 into vector<4x8x16xf32>
|
||||
%1 = vector.insert %a, %b[] : f32 into vector<4x8x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @insert_element_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
|
||||
func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
|
||||
// expected-error@+1 {{expected position attribute of rank smaller than dest vector rank}}
|
||||
%1 = vector.insertelement %a, %b[3 : i32,3 : i32,3 : i32,3 : i32,3 : i32,3 : i32] : f32 into vector<4x8x16xf32>
|
||||
%1 = vector.insert %a, %b[3 : i32,3 : i32,3 : i32,3 : i32,3 : i32,3 : i32] : f32 into vector<4x8x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @insert_element_vector_type(%a: vector<4xf32>, %b: vector<4x8x16xf32>) {
|
||||
func @insert_vector_type(%a: vector<4xf32>, %b: vector<4x8x16xf32>) {
|
||||
// expected-error@+1 {{expected position attribute rank + source rank to match dest vector rank}}
|
||||
%1 = vector.insertelement %a, %b[3 : i32] : vector<4xf32> into vector<4x8x16xf32>
|
||||
%1 = vector.insert %a, %b[3 : i32] : vector<4xf32> into vector<4x8x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @insert_element_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
|
||||
func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
|
||||
// expected-error@+1 {{expected position attribute rank to match the dest vector rank}}
|
||||
%1 = vector.insertelement %a, %b[3 : i32,3 : i32] : f32 into vector<4x8x16xf32>
|
||||
%1 = vector.insert %a, %b[3 : i32,3 : i32] : f32 into vector<4x8x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @insertelement_position_overflow(%a: f32, %b: vector<4x8x16xf32>) {
|
||||
func @insert_position_overflow(%a: f32, %b: vector<4x8x16xf32>) {
|
||||
// expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding dest vector dimension}}
|
||||
%1 = vector.insertelement %a, %b[0 : i32, 0 : i32, -1 : i32] : f32 into vector<4x8x16xf32>
|
||||
%1 = vector.insert %a, %b[0 : i32, 0 : i32, -1 : i32] : f32 into vector<4x8x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
|
|
@ -35,25 +35,25 @@ func @vector_broadcast(%a: f32, %b: vector<16xf32>, %c: vector<1x16xf32>, %d: ve
|
|||
return %3 : vector<8x16xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @extractelement
|
||||
func @extractelement(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16xf32>, f32) {
|
||||
// CHECK: vector.extractelement {{.*}}[3 : i32] : vector<4x8x16xf32>
|
||||
%1 = vector.extractelement %arg0[3 : i32] : vector<4x8x16xf32>
|
||||
// CHECK-NEXT: vector.extractelement {{.*}}[3 : i32, 3 : i32] : vector<4x8x16xf32>
|
||||
%2 = vector.extractelement %arg0[3 : i32, 3 : i32] : vector<4x8x16xf32>
|
||||
// CHECK-NEXT: vector.extractelement {{.*}}[3 : i32, 3 : i32, 3 : i32] : vector<4x8x16xf32>
|
||||
%3 = vector.extractelement %arg0[3 : i32, 3 : i32, 3 : i32] : vector<4x8x16xf32>
|
||||
// CHECK-LABEL: @extract
|
||||
func @extract(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16xf32>, f32) {
|
||||
// CHECK: vector.extract {{.*}}[3 : i32] : vector<4x8x16xf32>
|
||||
%1 = vector.extract %arg0[3 : i32] : vector<4x8x16xf32>
|
||||
// CHECK-NEXT: vector.extract {{.*}}[3 : i32, 3 : i32] : vector<4x8x16xf32>
|
||||
%2 = vector.extract %arg0[3 : i32, 3 : i32] : vector<4x8x16xf32>
|
||||
// CHECK-NEXT: vector.extract {{.*}}[3 : i32, 3 : i32, 3 : i32] : vector<4x8x16xf32>
|
||||
%3 = vector.extract %arg0[3 : i32, 3 : i32, 3 : i32] : vector<4x8x16xf32>
|
||||
return %1, %2, %3 : vector<8x16xf32>, vector<16xf32>, f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @insertelement
|
||||
func @insertelement(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vector<4x8x16xf32>) {
|
||||
// CHECK: vector.insertelement %{{.*}}, %{{.*}}[3 : i32] : vector<8x16xf32> into vector<4x8x16xf32>
|
||||
%1 = vector.insertelement %c, %res[3 : i32] : vector<8x16xf32> into vector<4x8x16xf32>
|
||||
// CHECK: vector.insertelement %{{.*}}, %{{.*}}[3 : i32, 3 : i32] : vector<16xf32> into vector<4x8x16xf32>
|
||||
%2 = vector.insertelement %b, %res[3 : i32, 3 : i32] : vector<16xf32> into vector<4x8x16xf32>
|
||||
// CHECK: vector.insertelement %{{.*}}, %{{.*}}[3 : i32, 3 : i32, 3 : i32] : f32 into vector<4x8x16xf32>
|
||||
%3 = vector.insertelement %a, %res[3 : i32, 3 : i32, 3 : i32] : f32 into vector<4x8x16xf32>
|
||||
// CHECK-LABEL: @insert
|
||||
func @insert(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vector<4x8x16xf32>) {
|
||||
// CHECK: vector.insert %{{.*}}, %{{.*}}[3 : i32] : vector<8x16xf32> into vector<4x8x16xf32>
|
||||
%1 = vector.insert %c, %res[3 : i32] : vector<8x16xf32> into vector<4x8x16xf32>
|
||||
// CHECK: vector.insert %{{.*}}, %{{.*}}[3 : i32, 3 : i32] : vector<16xf32> into vector<4x8x16xf32>
|
||||
%2 = vector.insert %b, %res[3 : i32, 3 : i32] : vector<16xf32> into vector<4x8x16xf32>
|
||||
// CHECK: vector.insert %{{.*}}, %{{.*}}[3 : i32, 3 : i32, 3 : i32] : f32 into vector<4x8x16xf32>
|
||||
%3 = vector.insert %a, %res[3 : i32, 3 : i32, 3 : i32] : f32 into vector<4x8x16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue