forked from OSchip/llvm-project
Add VectorOps.StridedSliceOp
The `vector.strided_slice` takes an n-D vector, k-D `offsets` integer array attribute, a k-D `sizes` integer array attribute, a k-D `strides` integer array attribute and extracts the n-D subvector at the proper offset. Returns an n-D vector where the first k-D dimensions match the `sizes` attribute. The returned subvector contains the elements starting at offset `offsets` and ending at `offsets + sizes`. Example: ``` %1 = vector.strided_slice %0 {offsets : [0, 2], sizes : [2, 4], strides : [1, 1]}: vector<4x8x16xf32> // returns a vector<2x4x16xf32> ``` This op will be useful for progressive lowering within the VectorOp dialect. PiperOrigin-RevId: 281352749
This commit is contained in:
parent
3732ba4def
commit
ee95f6f259
|
@ -76,6 +76,48 @@ def VectorExtractElementOp :
|
|||
}];
|
||||
}
|
||||
|
||||
def VectorStridedSliceOp :
|
||||
Vector_Op<"strided_slice", [NoSideEffect,
|
||||
PredOpTrait<"operand and result have same element type",
|
||||
TCresVTEtIsSameAsOpBase<0, 0>>]>,
|
||||
Arguments<(ins AnyVector:$vector, I64ArrayAttr:$offsets,
|
||||
I64ArrayAttr:$sizes, I64ArrayAttr:$strides)>,
|
||||
Results<(outs AnyVector)> {
|
||||
let summary = "strided_slice operation";
|
||||
let description = [{
|
||||
Takes an n-D vector, k-D `offsets` integer array attribute, a k-D `sizes`
|
||||
integer array attribute, a k-D `strides` integer array attribute and
|
||||
extracts the n-D subvector at the proper offset.
|
||||
|
||||
At the moment strides must contain only 1s.
|
||||
|
||||
Returns an n-D vector where the first k-D dimensions match the `sizes`
|
||||
attribute. The returned subvector contains the elements starting at offset
|
||||
`offsets` and ending at `offsets + sizes`.
|
||||
|
||||
Examples:
|
||||
```
|
||||
%1 = vector.strided_slice %0
|
||||
{offsets : [0, 2], sizes : [2, 4], strides : [1, 1]}:
|
||||
vector<4x8x16xf32> to vector<2x4x16xf32>
|
||||
```
|
||||
|
||||
// TODO(Evolve to a range form syntax):
|
||||
%1 = vector.strided_slice %0[0:2:1][2:4:1]
|
||||
vector<4x8x16xf32> to vector<2x4x16xf32>
|
||||
}];
|
||||
let builders = [OpBuilder<
|
||||
"Builder *builder, OperationState &result, Value *source, " #
|
||||
"ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, " #
|
||||
"ArrayRef<int64_t> strides">];
|
||||
let extraClassDeclaration = [{
|
||||
static StringRef getOffsetsAttrName() { return "offsets"; }
|
||||
static StringRef getSizesAttrName() { return "sizes"; }
|
||||
static StringRef getStridesAttrName() { return "strides"; }
|
||||
VectorType getVectorType(){ return vector()->getType().cast<VectorType>(); }
|
||||
}];
|
||||
}
|
||||
|
||||
def VectorOuterProductOp :
|
||||
Vector_Op<"outerproduct", [NoSideEffect, SameOperandsAndResultElementType]>,
|
||||
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, Variadic<AnyVector>:$acc)>,
|
||||
|
|
|
@ -92,7 +92,7 @@ static ParseResult parseVectorExtractElementOp(OpAsmParser &parser,
|
|||
static_cast<int64_t>(positionAttr.size()) > vectorType.getRank())
|
||||
return parser.emitError(
|
||||
attributeLoc,
|
||||
"expected position attribute of rank smaller than vector");
|
||||
"expected position attribute of rank smaller than vector rank");
|
||||
|
||||
Type resType = inferExtractOpResultType(vectorType, positionAttr);
|
||||
result.attributes = attrs;
|
||||
|
@ -106,7 +106,7 @@ static LogicalResult verify(VectorExtractElementOp op) {
|
|||
return op.emitOpError("expected non-empty position attribute");
|
||||
if (positionAttr.size() > static_cast<unsigned>(op.getVectorType().getRank()))
|
||||
return op.emitOpError(
|
||||
"expected position attribute of rank smaller than vector");
|
||||
"expected position attribute of rank smaller than vector rank");
|
||||
for (auto en : llvm::enumerate(positionAttr)) {
|
||||
auto attr = en.value().dyn_cast<IntegerAttr>();
|
||||
if (!attr || attr.getInt() < 0 ||
|
||||
|
@ -119,6 +119,180 @@ static LogicalResult verify(VectorExtractElementOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// VectorStridedSliceOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static Type inferVectorExtractRangeOpResultType(VectorType vectorType,
|
||||
ArrayAttr offsets,
|
||||
ArrayAttr sizes,
|
||||
ArrayAttr strides) {
|
||||
assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
|
||||
SmallVector<int64_t, 4> shape;
|
||||
shape.reserve(vectorType.getRank());
|
||||
unsigned idx = 0;
|
||||
for (unsigned e = offsets.size(); idx < e; ++idx)
|
||||
shape.push_back(sizes.getValue()[idx].cast<IntegerAttr>().getInt());
|
||||
for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
|
||||
shape.push_back(vectorType.getShape()[idx]);
|
||||
|
||||
return VectorType::get(shape, vectorType.getElementType());
|
||||
}
|
||||
|
||||
void VectorStridedSliceOp::build(Builder *builder, OperationState &result,
|
||||
Value *source, ArrayRef<int64_t> offsets,
|
||||
ArrayRef<int64_t> sizes,
|
||||
ArrayRef<int64_t> strides) {
|
||||
result.addOperands(source);
|
||||
auto offsetsAttr = builder->getI64ArrayAttr(offsets);
|
||||
auto sizesAttr = builder->getI64ArrayAttr(sizes);
|
||||
auto stridesAttr = builder->getI64ArrayAttr(strides);
|
||||
result.addTypes(
|
||||
inferVectorExtractRangeOpResultType(source->getType().cast<VectorType>(),
|
||||
offsetsAttr, sizesAttr, stridesAttr));
|
||||
result.addAttribute(getOffsetsAttrName(), offsetsAttr);
|
||||
result.addAttribute(getSizesAttrName(), sizesAttr);
|
||||
result.addAttribute(getStridesAttrName(), stridesAttr);
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, VectorStridedSliceOp op) {
|
||||
p << op.getOperationName() << " " << *op.vector();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.vector()->getType() << " to " << op.getResult()->getType();
|
||||
}
|
||||
|
||||
static ParseResult parseVectorStridedSliceOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
llvm::SMLoc attributeLoc, typeLoc;
|
||||
OpAsmParser::OperandType vector;
|
||||
VectorType vectorType, resultVectorType;
|
||||
return failure(parser.parseOperand(vector) ||
|
||||
parser.getCurrentLocation(&attributeLoc) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.getCurrentLocation(&typeLoc) ||
|
||||
parser.parseColonType(vectorType) ||
|
||||
parser.parseKeywordType("to", resultVectorType) ||
|
||||
parser.resolveOperand(vector, vectorType, result.operands) ||
|
||||
parser.addTypeToList(resultVectorType, result.types));
|
||||
}
|
||||
|
||||
// TODO(ntv) Should be moved to Tablegen Confined attributes.
|
||||
static bool isIntegerArrayAttrSmallerThanShape(VectorStridedSliceOp op,
|
||||
ArrayAttr arrayAttr,
|
||||
ShapedType shape,
|
||||
StringRef attrName) {
|
||||
if (arrayAttr.size() > static_cast<unsigned>(shape.getRank())) {
|
||||
op.emitOpError("expected ")
|
||||
<< attrName << " attribute of rank smaller than vector rank";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns true if all integers in `arrayAttr` are in the half-open [min, max}
|
||||
// interval. If `halfOpen` is true then the admissible interval is [min, max).
|
||||
// Otherwise, the admissible interval is [min, max].
|
||||
static bool isIntegerArrayAttrConfinedToRange(VectorStridedSliceOp op,
|
||||
ArrayAttr arrayAttr, int64_t min,
|
||||
int64_t max, StringRef attrName,
|
||||
bool halfOpen = true) {
|
||||
for (auto attr : arrayAttr) {
|
||||
auto val = attr.cast<IntegerAttr>().getInt();
|
||||
auto upper = max;
|
||||
if (!halfOpen)
|
||||
upper += 1;
|
||||
if (val < min || val >= upper) {
|
||||
op.emitOpError("expected ")
|
||||
<< attrName << " to be confined to [" << min << ", " << upper << ")";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns true if all integers in `arrayAttr` are in the half-open [min, max}
|
||||
// interval. If `halfOpen` is true then the admissible interval is [min, max).
|
||||
// Otherwise, the admissible interval is [min, max].
|
||||
static bool
|
||||
isIntegerArrayAttrConfinedToShape(VectorStridedSliceOp op, ArrayAttr arrayAttr,
|
||||
ShapedType shape, StringRef attrName,
|
||||
bool halfOpen = true, int64_t min = 0) {
|
||||
assert(arrayAttr.size() <= static_cast<unsigned>(shape.getRank()));
|
||||
for (auto it : llvm::zip(arrayAttr, shape.getShape())) {
|
||||
auto val = std::get<0>(it).cast<IntegerAttr>().getInt();
|
||||
auto max = std::get<1>(it);
|
||||
if (!halfOpen)
|
||||
max += 1;
|
||||
if (val < min || val >= max) {
|
||||
op.emitOpError("expected ")
|
||||
<< attrName << " to be confined to [" << min << ", " << max << ")";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns true if all integers in `arrayAttr` are in the interval [min, max}.
|
||||
// interval. If `halfOpen` is true then the admissible interval is [min, max).
|
||||
// Otherwise, the admissible interval is [min, max].
|
||||
static bool isSumOfIntegerArrayAttrConfinedToShape(
|
||||
VectorStridedSliceOp op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
|
||||
ShapedType shape, StringRef attrName1, StringRef attrName2,
|
||||
bool halfOpen = true, int64_t min = 1) {
|
||||
assert(arrayAttr1.size() <= static_cast<unsigned>(shape.getRank()));
|
||||
assert(arrayAttr2.size() <= static_cast<unsigned>(shape.getRank()));
|
||||
for (auto it : llvm::zip(arrayAttr1, arrayAttr2, shape.getShape())) {
|
||||
auto val1 = std::get<0>(it).cast<IntegerAttr>().getInt();
|
||||
auto val2 = std::get<1>(it).cast<IntegerAttr>().getInt();
|
||||
auto max = std::get<2>(it);
|
||||
if (!halfOpen)
|
||||
max += 1;
|
||||
if (val1 + val2 < 0 || val1 + val2 >= max) {
|
||||
op.emitOpError("expected sum(")
|
||||
<< attrName1 << ", " << attrName2 << ") to be confined to [" << min
|
||||
<< ", " << max << ")";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static LogicalResult verify(VectorStridedSliceOp op) {
|
||||
auto type = op.getVectorType();
|
||||
auto offsets = op.offsets();
|
||||
auto sizes = op.sizes();
|
||||
auto strides = op.strides();
|
||||
if (offsets.size() != sizes.size() || offsets.size() != strides.size()) {
|
||||
op.emitOpError(
|
||||
"expected offsets, sizes and strides attributes of same size");
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto offName = VectorStridedSliceOp::getOffsetsAttrName();
|
||||
auto sizesName = VectorStridedSliceOp::getSizesAttrName();
|
||||
auto stridesName = VectorStridedSliceOp::getStridesAttrName();
|
||||
if (!isIntegerArrayAttrSmallerThanShape(op, offsets, type, offName) ||
|
||||
!isIntegerArrayAttrSmallerThanShape(op, sizes, type, sizesName) ||
|
||||
!isIntegerArrayAttrSmallerThanShape(op, strides, type, stridesName) ||
|
||||
!isIntegerArrayAttrConfinedToShape(op, offsets, type, offName) ||
|
||||
!isIntegerArrayAttrConfinedToShape(op, sizes, type, sizesName,
|
||||
/*halfOpen=*/false, /*min=*/1) ||
|
||||
!isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName,
|
||||
/*halfOpen=*/false) ||
|
||||
!isSumOfIntegerArrayAttrConfinedToShape(op, offsets, sizes, type, offName,
|
||||
sizesName, /*halfOpen=*/false))
|
||||
return failure();
|
||||
|
||||
auto resultType = inferVectorExtractRangeOpResultType(
|
||||
op.getVectorType(), op.offsets(), op.sizes(), op.strides());
|
||||
if (op.getResult()->getType() != resultType) {
|
||||
op.emitOpError("expected result type to be ") << resultType;
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// VectorOuterProductOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -231,6 +231,7 @@ func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
|
|||
// expected-error@+1 {{requires a projected permutation_map (at most one dim or the zero constant can appear in each result)}}
|
||||
vector.transfer_write %cst, %arg0[%c3, %c3] {permutation_map = (d0, d1)->(d0 + 1)} : vector<128xf32>, memref<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_vector.transfer_write(%arg0: memref<?x?x?xf32>) {
|
||||
|
@ -239,3 +240,66 @@ func @test_vector.transfer_write(%arg0: memref<?x?x?xf32>) {
|
|||
// expected-error@+1 {{requires a permutation_map that is a permutation (found one dim used more than once)}}
|
||||
vector.transfer_write %cst, %arg0[%c3, %c3, %c3] {permutation_map = (d0, d1, d2)->(d0, d0)} : vector<3x7xf32>, memref<?x?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @strided_slice(%arg0: vector<4x8x16xf32>) {
|
||||
// expected-error@+1 {{expected offsets, sizes and strides attributes of same size}}
|
||||
%1 = vector.strided_slice %arg0 {offsets = [100], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32> to vector<2x2x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @strided_slice(%arg0: vector<4x8x16xf32>) {
|
||||
// expected-error@+1 {{expected offsets attribute of rank smaller than vector rank}}
|
||||
%1 = vector.strided_slice %arg0 {offsets = [2, 2, 2, 2], sizes = [2, 2, 2, 2], strides = [1, 1, 1, 1]} : vector<4x8x16xf32> to vector<2x2x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @strided_slice(%arg0: vector<4x8x16xf32>) {
|
||||
// expected-error@+1 {{expected offsets attribute of rank smaller than vector rank}}
|
||||
%1 = vector.strided_slice %arg0 {offsets = [2, 2, 2, 2], sizes = [2, 2, 2, 2], strides = [1, 1, 1, 1]} : vector<4x8x16xf32> to vector<2x2x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @strided_slice(%arg0: vector<4x8x16xf32>) {
|
||||
// expected-error@+1 {{op expected offsets to be confined to [0, 4)}}
|
||||
%1 = vector.strided_slice %arg0 {offsets = [100], sizes = [100], strides = [100]} : vector<4x8x16xf32> to vector<100x8x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @strided_slice(%arg0: vector<4x8x16xf32>) {
|
||||
// expected-error@+1 {{op expected sizes to be confined to [1, 5)}}
|
||||
%1 = vector.strided_slice %arg0 {offsets = [2], sizes = [100], strides = [100]} : vector<4x8x16xf32> to vector<100x8x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @strided_slice(%arg0: vector<4x8x16xf32>) {
|
||||
// expected-error@+1 {{op expected strides to be confined to [1, 2)}}
|
||||
%1 = vector.strided_slice %arg0 {offsets = [2], sizes = [1], strides = [100]} : vector<4x8x16xf32> to vector<1x8x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @strided_slice(%arg0: vector<4x8x16xf32>) {
|
||||
// expected-error@+1 {{op expected strides to be confined to [1, 2)}}
|
||||
%1 = vector.strided_slice %arg0 {offsets = [2], sizes = [1], strides = [100]} : vector<4x8x16xf32> to vector<1x8x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @strided_slice(%arg0: vector<4x8x16xf32>) {
|
||||
// expected-error@+1 {{op expected sum(offsets, sizes) to be confined to [1, 5)}}
|
||||
%1 = vector.strided_slice %arg0 {offsets = [2], sizes = [3], strides = [1]} : vector<4x8x16xf32> to vector<3x8x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @strided_slice(%arg0: vector<4x8x16xf32>) {
|
||||
// expected-error@+1 {{op expected result type to be 'vector<2x8x16xf32>'}}
|
||||
%1 = vector.strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4x8x16xf32> to vector<3x1xf32>
|
||||
}
|
||||
|
|
|
@ -41,3 +41,10 @@ func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x8
|
|||
%1 = vector.outerproduct %arg0, %arg1, %arg2 : vector<4xf32>, vector<8xf32>
|
||||
return %1 : vector<4x8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: strided_slice
|
||||
func @strided_slice(%arg0: vector<4x8x16xf32>) -> vector<2x2x16xf32> {
|
||||
// CHECK: vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32>
|
||||
%1 = vector.strided_slice %arg0 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32> to vector<2x2x16xf32>
|
||||
return %1: vector<2x2x16xf32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue