forked from OSchip/llvm-project
Add vector.insertelement op
This is the counterpart of vector.extractelement op and has the same limitations at the moment (static I64IntegerArrayAttr to express position). This restriction will be filterd in the future. LLVM lowering will be added in a subsequent commit. PiperOrigin-RevId: 282365760
This commit is contained in:
parent
bf4692dc49
commit
01145544aa
|
@ -169,6 +169,40 @@ def Vector_ExtractElementOp :
|
|||
}];
|
||||
}
|
||||
|
||||
def Vector_InsertElementOp :
|
||||
Vector_Op<"insertelement", [NoSideEffect,
|
||||
PredOpTrait<"source operand and result have same element type",
|
||||
TCresVTEtIsSameAsOpBase<0, 0>>,
|
||||
PredOpTrait<"dest operand and result have same type",
|
||||
TCresIsSameAsOpBase<0, 1>>]>,
|
||||
Arguments<(ins AnyType:$source, AnyVector:$dest, I32ArrayAttr:$position)>,
|
||||
Results<(outs AnyVector)> {
|
||||
let summary = "insertelement operation";
|
||||
let description = [{
|
||||
Takes an n-D source vector, an (n+k)-D destination vector and a k-D position
|
||||
and inserts the n-D source into the (n+k)-D destination at the proper
|
||||
position. Degenerates to a scalar source type when n = 0.
|
||||
|
||||
Examples:
|
||||
```
|
||||
%2 = vector.insertelement %0, %1[3 : i32]:
|
||||
vector<8x16xf32> into vector<4x8x16xf32>
|
||||
%5 = vector.insertelement %3, %4[3 : i32, 3 : i32, 3 : i32]:
|
||||
f32 into vector<4x8x16xf32>
|
||||
```
|
||||
}];
|
||||
let builders = [OpBuilder<
|
||||
"Builder *builder, OperationState &result, Value *source, " #
|
||||
"Value *dest, ArrayRef<int32_t>">];
|
||||
let extraClassDeclaration = [{
|
||||
static StringRef getPositionAttrName() { return "position"; }
|
||||
Type getSourceType() { return source()->getType(); }
|
||||
VectorType getDestVectorType() {
|
||||
return dest()->getType().cast<VectorType>();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Vector_StridedSliceOp :
|
||||
Vector_Op<"strided_slice", [NoSideEffect,
|
||||
PredOpTrait<"operand and result have same element type",
|
||||
|
|
|
@ -1668,6 +1668,12 @@ class TCOpResIsShapedTypePred<int i, int j> : And<[
|
|||
SubstLeaves<"$_self", "$_op.getOperand(" # j # ")->getType()",
|
||||
IsShapedTypePred>]>;
|
||||
|
||||
// Predicate to verify that the i'th result and the j'th operand have the same
|
||||
// type.
|
||||
class TCresIsSameAsOpBase<int i, int j> :
|
||||
CPred<"$_op.getResult(" # i # ")->getType() == "
|
||||
"$_op.getOperand(" # j # ")->getType()">;
|
||||
|
||||
// Basic Predicate to verify that the i'th result and the j'th operand have the
|
||||
// same elemental type.
|
||||
class TCresVTEtIsSameAsOpBase<int i, int j> :
|
||||
|
|
|
@ -291,12 +291,84 @@ static LogicalResult verify(ExtractElementOp op) {
|
|||
attr.getInt() > op.getVectorType().getDimSize(en.index()))
|
||||
return op.emitOpError("expected position attribute #")
|
||||
<< (en.index() + 1)
|
||||
<< " to be a positive integer smaller than the corresponding "
|
||||
<< " to be a non-negative integer smaller than the corresponding "
|
||||
"vector dimension";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// InsertElementOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void InsertElementOp::build(Builder *builder, OperationState &result,
|
||||
Value *source, Value *dest,
|
||||
ArrayRef<int32_t> position) {
|
||||
result.addOperands({source, dest});
|
||||
auto positionAttr = builder->getI32ArrayAttr(position);
|
||||
result.addTypes(dest->getType());
|
||||
result.addAttribute(getPositionAttrName(), positionAttr);
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, InsertElementOp op) {
|
||||
p << op.getOperationName() << " " << *op.source() << ", " << *op.dest()
|
||||
<< op.position();
|
||||
p.printOptionalAttrDict(op.getAttrs(),
|
||||
{InsertElementOp::getPositionAttrName()});
|
||||
p << " : " << op.getSourceType();
|
||||
p << " into " << op.getDestVectorType();
|
||||
}
|
||||
|
||||
static ParseResult parseInsertElementOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
SmallVector<NamedAttribute, 4> attrs;
|
||||
OpAsmParser::OperandType source, dest;
|
||||
Type sourceType;
|
||||
VectorType destType;
|
||||
Attribute attr;
|
||||
return failure(parser.parseOperand(source) || parser.parseComma() ||
|
||||
parser.parseOperand(dest) ||
|
||||
parser.parseAttribute(attr,
|
||||
InsertElementOp::getPositionAttrName(),
|
||||
result.attributes) ||
|
||||
parser.parseOptionalAttrDict(attrs) ||
|
||||
parser.parseColonType(sourceType) ||
|
||||
parser.parseKeywordType("into", destType) ||
|
||||
parser.resolveOperand(source, sourceType, result.operands) ||
|
||||
parser.resolveOperand(dest, destType, result.operands) ||
|
||||
parser.addTypeToList(destType, result.types));
|
||||
}
|
||||
|
||||
static LogicalResult verify(InsertElementOp op) {
|
||||
auto positionAttr = op.position().getValue();
|
||||
if (positionAttr.empty())
|
||||
return op.emitOpError("expected non-empty position attribute");
|
||||
auto destVectorType = op.getDestVectorType();
|
||||
if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank()))
|
||||
return op.emitOpError(
|
||||
"expected position attribute of rank smaller than dest vector rank");
|
||||
auto srcVectorType = op.getSourceType().dyn_cast<VectorType>();
|
||||
if (srcVectorType &&
|
||||
(static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() !=
|
||||
static_cast<unsigned>(destVectorType.getRank())))
|
||||
return op.emitOpError("expected position attribute rank + source rank to "
|
||||
"match dest vector rank");
|
||||
else if (!srcVectorType && (positionAttr.size() !=
|
||||
static_cast<unsigned>(destVectorType.getRank())))
|
||||
return op.emitOpError(
|
||||
"expected position attribute rank to match the dest vector rank");
|
||||
for (auto en : llvm::enumerate(positionAttr)) {
|
||||
auto attr = en.value().dyn_cast<IntegerAttr>();
|
||||
if (!attr || attr.getInt() < 0 ||
|
||||
attr.getInt() > destVectorType.getDimSize(en.index()))
|
||||
return op.emitOpError("expected position attribute #")
|
||||
<< (en.index() + 1)
|
||||
<< " to be a non-negative integer smaller than the corresponding "
|
||||
"dest vector dimension";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// StridedSliceOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -31,19 +31,54 @@ func @extractelement_position_rank_overflow_generic(%arg0: vector<4x8x16xf32>) {
|
|||
// -----
|
||||
|
||||
func @extractelement_position_overflow(%arg0: vector<4x8x16xf32>) {
|
||||
// expected-error@+1 {{expected position attribute #2 to be a positive integer smaller than the corresponding vector dimension}}
|
||||
// expected-error@+1 {{expected position attribute #2 to be a non-negative integer smaller than the corresponding vector dimension}}
|
||||
%1 = vector.extractelement %arg0[0 : i32, 43 : i32, 0 : i32] : vector<4x8x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @extractelement_position_overflow(%arg0: vector<4x8x16xf32>) {
|
||||
// expected-error@+1 {{expected position attribute #3 to be a positive integer smaller than the corresponding vector dimension}}
|
||||
// expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension}}
|
||||
%1 = vector.extractelement %arg0[0 : i32, 0 : i32, -1 : i32] : vector<4x8x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @insert_element_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
|
||||
// expected-error@+1 {{expected non-empty position attribute}}
|
||||
%1 = vector.insertelement %a, %b[] : f32 into vector<4x8x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @insert_element_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
|
||||
// expected-error@+1 {{expected position attribute of rank smaller than dest vector rank}}
|
||||
%1 = vector.insertelement %a, %b[3 : i32,3 : i32,3 : i32,3 : i32,3 : i32,3 : i32] : f32 into vector<4x8x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @insert_element_vector_type(%a: vector<4xf32>, %b: vector<4x8x16xf32>) {
|
||||
// expected-error@+1 {{expected position attribute rank + source rank to match dest vector rank}}
|
||||
%1 = vector.insertelement %a, %b[3 : i32] : vector<4xf32> into vector<4x8x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @insert_element_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
|
||||
// expected-error@+1 {{expected position attribute rank to match the dest vector rank}}
|
||||
%1 = vector.insertelement %a, %b[3 : i32,3 : i32] : f32 into vector<4x8x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @insertelement_position_overflow(%a: f32, %b: vector<4x8x16xf32>) {
|
||||
// expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding dest vector dimension}}
|
||||
%1 = vector.insertelement %a, %b[0 : i32, 0 : i32, -1 : i32] : f32 into vector<4x8x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @outerproduct_num_operands(%arg0: f32) {
|
||||
// expected-error@+1 {{expected at least 2 operands}}
|
||||
%1 = vector.outerproduct %arg0 : f32, f32
|
||||
|
@ -369,5 +404,3 @@ func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
|
|||
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -33,6 +33,17 @@ func @extractelement(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16x
|
|||
return %1, %2, %3 : vector<8x16xf32>, vector<16xf32>, f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: insertelement
|
||||
func @insertelement(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vector<4x8x16xf32>) {
|
||||
// CHECK: vector.insertelement %{{.*}}, %{{.*}}[3 : i32] : vector<8x16xf32> into vector<4x8x16xf32>
|
||||
%1 = vector.insertelement %c, %res[3 : i32] : vector<8x16xf32> into vector<4x8x16xf32>
|
||||
// CHECK: vector.insertelement %{{.*}}, %{{.*}}[3 : i32, 3 : i32] : vector<16xf32> into vector<4x8x16xf32>
|
||||
%2 = vector.insertelement %b, %res[3 : i32, 3 : i32] : vector<16xf32> into vector<4x8x16xf32>
|
||||
// CHECK: vector.insertelement %{{.*}}, %{{.*}}[3 : i32, 3 : i32, 3 : i32] : f32 into vector<4x8x16xf32>
|
||||
%3 = vector.insertelement %a, %res[3 : i32, 3 : i32, 3 : i32] : f32 into vector<4x8x16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: outerproduct
|
||||
func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x8xf32>) -> vector<4x8xf32> {
|
||||
// CHECK: vector.outerproduct {{.*}} : vector<4xf32>, vector<8xf32>
|
||||
|
|
Loading…
Reference in New Issue