Add vector.insertelement op

This is the counterpart of vector.extractelement op and has the same
limitations at the moment (static I64IntegerArrayAttr to express position).
This restriction will be filterd in the future.
LLVM lowering will be added in a subsequent commit.

PiperOrigin-RevId: 282365760
This commit is contained in:
Nicolas Vasilache 2019-11-25 08:46:37 -08:00 committed by A. Unique TensorFlower
parent bf4692dc49
commit 01145544aa
5 changed files with 161 additions and 5 deletions

View File

@ -169,6 +169,40 @@ def Vector_ExtractElementOp :
}];
}
def Vector_InsertElementOp :
Vector_Op<"insertelement", [NoSideEffect,
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
PredOpTrait<"dest operand and result have same type",
TCresIsSameAsOpBase<0, 1>>]>,
Arguments<(ins AnyType:$source, AnyVector:$dest, I32ArrayAttr:$position)>,
Results<(outs AnyVector)> {
let summary = "insertelement operation";
let description = [{
Takes an n-D source vector, an (n+k)-D destination vector and a k-D position
and inserts the n-D source into the (n+k)-D destination at the proper
position. Degenerates to a scalar source type when n = 0.
Examples:
```
%2 = vector.insertelement %0, %1[3 : i32]:
vector<8x16xf32> into vector<4x8x16xf32>
%5 = vector.insertelement %3, %4[3 : i32, 3 : i32, 3 : i32]:
f32 into vector<4x8x16xf32>
```
}];
let builders = [OpBuilder<
"Builder *builder, OperationState &result, Value *source, " #
"Value *dest, ArrayRef<int32_t>">];
let extraClassDeclaration = [{
static StringRef getPositionAttrName() { return "position"; }
Type getSourceType() { return source()->getType(); }
VectorType getDestVectorType() {
return dest()->getType().cast<VectorType>();
}
}];
}
def Vector_StridedSliceOp :
Vector_Op<"strided_slice", [NoSideEffect,
PredOpTrait<"operand and result have same element type",

View File

@ -1668,6 +1668,12 @@ class TCOpResIsShapedTypePred<int i, int j> : And<[
SubstLeaves<"$_self", "$_op.getOperand(" # j # ")->getType()",
IsShapedTypePred>]>;
// Predicate to verify that the i'th result and the j'th operand have the same
// type.
class TCresIsSameAsOpBase<int i, int j> :
CPred<"$_op.getResult(" # i # ")->getType() == "
"$_op.getOperand(" # j # ")->getType()">;
// Basic Predicate to verify that the i'th result and the j'th operand have the
// same elemental type.
class TCresVTEtIsSameAsOpBase<int i, int j> :

View File

@ -291,12 +291,84 @@ static LogicalResult verify(ExtractElementOp op) {
attr.getInt() > op.getVectorType().getDimSize(en.index()))
return op.emitOpError("expected position attribute #")
<< (en.index() + 1)
<< " to be a positive integer smaller than the corresponding "
<< " to be a non-negative integer smaller than the corresponding "
"vector dimension";
}
return success();
}
//===----------------------------------------------------------------------===//
// InsertElementOp
//===----------------------------------------------------------------------===//
void InsertElementOp::build(Builder *builder, OperationState &result,
Value *source, Value *dest,
ArrayRef<int32_t> position) {
result.addOperands({source, dest});
auto positionAttr = builder->getI32ArrayAttr(position);
result.addTypes(dest->getType());
result.addAttribute(getPositionAttrName(), positionAttr);
}
static void print(OpAsmPrinter &p, InsertElementOp op) {
p << op.getOperationName() << " " << *op.source() << ", " << *op.dest()
<< op.position();
p.printOptionalAttrDict(op.getAttrs(),
{InsertElementOp::getPositionAttrName()});
p << " : " << op.getSourceType();
p << " into " << op.getDestVectorType();
}
static ParseResult parseInsertElementOp(OpAsmParser &parser,
OperationState &result) {
SmallVector<NamedAttribute, 4> attrs;
OpAsmParser::OperandType source, dest;
Type sourceType;
VectorType destType;
Attribute attr;
return failure(parser.parseOperand(source) || parser.parseComma() ||
parser.parseOperand(dest) ||
parser.parseAttribute(attr,
InsertElementOp::getPositionAttrName(),
result.attributes) ||
parser.parseOptionalAttrDict(attrs) ||
parser.parseColonType(sourceType) ||
parser.parseKeywordType("into", destType) ||
parser.resolveOperand(source, sourceType, result.operands) ||
parser.resolveOperand(dest, destType, result.operands) ||
parser.addTypeToList(destType, result.types));
}
static LogicalResult verify(InsertElementOp op) {
auto positionAttr = op.position().getValue();
if (positionAttr.empty())
return op.emitOpError("expected non-empty position attribute");
auto destVectorType = op.getDestVectorType();
if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank()))
return op.emitOpError(
"expected position attribute of rank smaller than dest vector rank");
auto srcVectorType = op.getSourceType().dyn_cast<VectorType>();
if (srcVectorType &&
(static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() !=
static_cast<unsigned>(destVectorType.getRank())))
return op.emitOpError("expected position attribute rank + source rank to "
"match dest vector rank");
else if (!srcVectorType && (positionAttr.size() !=
static_cast<unsigned>(destVectorType.getRank())))
return op.emitOpError(
"expected position attribute rank to match the dest vector rank");
for (auto en : llvm::enumerate(positionAttr)) {
auto attr = en.value().dyn_cast<IntegerAttr>();
if (!attr || attr.getInt() < 0 ||
attr.getInt() > destVectorType.getDimSize(en.index()))
return op.emitOpError("expected position attribute #")
<< (en.index() + 1)
<< " to be a non-negative integer smaller than the corresponding "
"dest vector dimension";
}
return success();
}
//===----------------------------------------------------------------------===//
// StridedSliceOp
//===----------------------------------------------------------------------===//

View File

@ -31,19 +31,54 @@ func @extractelement_position_rank_overflow_generic(%arg0: vector<4x8x16xf32>) {
// -----
func @extractelement_position_overflow(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{expected position attribute #2 to be a positive integer smaller than the corresponding vector dimension}}
// expected-error@+1 {{expected position attribute #2 to be a non-negative integer smaller than the corresponding vector dimension}}
%1 = vector.extractelement %arg0[0 : i32, 43 : i32, 0 : i32] : vector<4x8x16xf32>
}
// -----
func @extractelement_position_overflow(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{expected position attribute #3 to be a positive integer smaller than the corresponding vector dimension}}
// expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension}}
%1 = vector.extractelement %arg0[0 : i32, 0 : i32, -1 : i32] : vector<4x8x16xf32>
}
// -----
func @insert_element_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
// expected-error@+1 {{expected non-empty position attribute}}
%1 = vector.insertelement %a, %b[] : f32 into vector<4x8x16xf32>
}
// -----
func @insert_element_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
// expected-error@+1 {{expected position attribute of rank smaller than dest vector rank}}
%1 = vector.insertelement %a, %b[3 : i32,3 : i32,3 : i32,3 : i32,3 : i32,3 : i32] : f32 into vector<4x8x16xf32>
}
// -----
func @insert_element_vector_type(%a: vector<4xf32>, %b: vector<4x8x16xf32>) {
// expected-error@+1 {{expected position attribute rank + source rank to match dest vector rank}}
%1 = vector.insertelement %a, %b[3 : i32] : vector<4xf32> into vector<4x8x16xf32>
}
// -----
func @insert_element_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
// expected-error@+1 {{expected position attribute rank to match the dest vector rank}}
%1 = vector.insertelement %a, %b[3 : i32,3 : i32] : f32 into vector<4x8x16xf32>
}
// -----
func @insertelement_position_overflow(%a: f32, %b: vector<4x8x16xf32>) {
// expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding dest vector dimension}}
%1 = vector.insertelement %a, %b[0 : i32, 0 : i32, -1 : i32] : f32 into vector<4x8x16xf32>
}
// -----
func @outerproduct_num_operands(%arg0: f32) {
// expected-error@+1 {{expected at least 2 operands}}
%1 = vector.outerproduct %arg0 : f32, f32
@ -369,5 +404,3 @@ func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
return
}

View File

@ -33,6 +33,17 @@ func @extractelement(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16x
return %1, %2, %3 : vector<8x16xf32>, vector<16xf32>, f32
}
// CHECK-LABEL: insertelement
func @insertelement(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vector<4x8x16xf32>) {
// CHECK: vector.insertelement %{{.*}}, %{{.*}}[3 : i32] : vector<8x16xf32> into vector<4x8x16xf32>
%1 = vector.insertelement %c, %res[3 : i32] : vector<8x16xf32> into vector<4x8x16xf32>
// CHECK: vector.insertelement %{{.*}}, %{{.*}}[3 : i32, 3 : i32] : vector<16xf32> into vector<4x8x16xf32>
%2 = vector.insertelement %b, %res[3 : i32, 3 : i32] : vector<16xf32> into vector<4x8x16xf32>
// CHECK: vector.insertelement %{{.*}}, %{{.*}}[3 : i32, 3 : i32, 3 : i32] : f32 into vector<4x8x16xf32>
%3 = vector.insertelement %a, %res[3 : i32, 3 : i32, 3 : i32] : f32 into vector<4x8x16xf32>
return
}
// CHECK-LABEL: outerproduct
func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x8xf32>) -> vector<4x8xf32> {
// CHECK: vector.outerproduct {{.*}} : vector<4xf32>, vector<8xf32>