forked from OSchip/llvm-project
[mlir][Vector] Thread 0-d vectors through ExtractElementOp.
This revision starts making concrete use of 0-d vectors to extend the semantics of ExtractElementOp. In the process a new VectorOfAnyRank Tablegen OpBase.td is added to allow progressive transition to supporting 0-d vectors by gradually opting in. Differential Revision: https://reviews.llvm.org/D114387
This commit is contained in:
parent
f24d9313cc
commit
e7026aba00
|
@ -482,14 +482,20 @@ def Vector_ExtractElementOp :
|
|||
TypesMatchWith<"result type matches element type of vector operand",
|
||||
"vector", "result",
|
||||
"$_self.cast<ShapedType>().getElementType()">]>,
|
||||
Arguments<(ins AnyVector:$vector, AnySignlessIntegerOrIndex:$position)>,
|
||||
Arguments<(ins AnyVectorOfAnyRank:$vector,
|
||||
Optional<AnySignlessIntegerOrIndex>:$position)>,
|
||||
Results<(outs AnyType:$result)> {
|
||||
let summary = "extractelement operation";
|
||||
let description = [{
|
||||
Takes an 1-D vector and a dynamic index position and extracts the
|
||||
scalar at that position. Note that this instruction resembles
|
||||
vector.extract, but is restricted to 1-D vectors and relaxed
|
||||
to dynamic indices. It is meant to be closer to LLVM's version:
|
||||
Takes a 0-D or 1-D vector and a optional dynamic index position and
|
||||
extracts the scalar at that position.
|
||||
|
||||
Note that this instruction resembles vector.extract, but is restricted to
|
||||
0-D and 1-D vectors and relaxed to dynamic indices.
|
||||
If the vector is 0-D, the position must be llvm::None.
|
||||
|
||||
|
||||
It is meant to be closer to LLVM's version:
|
||||
https://llvm.org/docs/LangRef.html#extractelement-instruction
|
||||
|
||||
Example:
|
||||
|
@ -497,14 +503,18 @@ def Vector_ExtractElementOp :
|
|||
```mlir
|
||||
%c = arith.constant 15 : i32
|
||||
%1 = vector.extractelement %0[%c : i32]: vector<16xf32>
|
||||
%2 = vector.extractelement %z[]: vector<f32>
|
||||
```
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
$vector `[` $position `:` type($position) `]` attr-dict `:` type($vector)
|
||||
$vector `[` ($position^ `:` type($position))? `]` attr-dict `:` type($vector)
|
||||
}];
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$source, "Value":$position)>
|
||||
// 0-D builder.
|
||||
OpBuilder<(ins "Value":$source)>,
|
||||
// 1-D + position builder.
|
||||
OpBuilder<(ins "Value":$source, "Value":$position)>,
|
||||
];
|
||||
let extraClassDeclaration = [{
|
||||
VectorType getVectorType() {
|
||||
|
|
|
@ -208,7 +208,12 @@ class SuccessorConstraint<Pred predicate, string summary = ""> :
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Whether a type is a VectorType.
|
||||
def IsVectorTypePred : CPred<"$_self.isa<::mlir::VectorType>()">;
|
||||
// Explicitly disallow 0-D vectors for now until we have good enough coverage.
|
||||
def IsVectorTypePred : And<[CPred<"$_self.isa<::mlir::VectorType>()">,
|
||||
CPred<"$_self.cast<::mlir::VectorType>().getRank() > 0">]>;
|
||||
|
||||
// Temporary vector type clone that allows gradual transition to 0-D vectors.
|
||||
def IsVectorOfAnyRankTypePred : CPred<"$_self.isa<::mlir::VectorType>()">;
|
||||
|
||||
// Whether a type is a TensorType.
|
||||
def IsTensorTypePred : CPred<"$_self.isa<::mlir::TensorType>()">;
|
||||
|
@ -598,6 +603,10 @@ class HasAnyRankOfPred<list<int> ranks> : And<[
|
|||
class VectorOf<list<Type> allowedTypes> :
|
||||
ShapedContainerType<allowedTypes, IsVectorTypePred, "vector",
|
||||
"::mlir::VectorType">;
|
||||
// Temporary vector type clone that allows gradual transition to 0-D vectors.
|
||||
class VectorOfAnyRankOf<list<Type> allowedTypes> :
|
||||
ShapedContainerType<allowedTypes, IsVectorOfAnyRankTypePred, "vector",
|
||||
"::mlir::VectorType">;
|
||||
|
||||
// Whether the number of elements of a vector is from the given
|
||||
// `allowedRanks` list
|
||||
|
@ -649,6 +658,8 @@ class VectorOfLengthAndType<list<int> allowedLengths,
|
|||
"::mlir::VectorType">;
|
||||
|
||||
def AnyVector : VectorOf<[AnyType]>;
|
||||
// Temporary vector type clone that allows gradual transition to 0-D vectors.
|
||||
def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
|
||||
|
||||
// Shaped types.
|
||||
|
||||
|
|
|
@ -369,13 +369,17 @@ Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) {
|
|||
return LLVM::LLVMPointerType::get(elementType, type.getMemorySpaceAsInt());
|
||||
}
|
||||
|
||||
/// Convert an n-D vector type to an LLVM vector type via (n-1)-D array type
|
||||
/// when n > 1. For example, `vector<4 x f32>` remains as is while,
|
||||
/// `vector<4x8x16xf32>` converts to `!llvm.array<4xarray<8 x vector<16xf32>>>`.
|
||||
/// Convert an n-D vector type to an LLVM vector type:
|
||||
/// * 0-D `vector<T>` are converted to vector<1xT>
|
||||
/// * 1-D `vector<axT>` remains as is while,
|
||||
/// * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to
|
||||
/// `!llvm.array<ax...array<jxvector<kxT>>>`.
|
||||
Type LLVMTypeConverter::convertVectorType(VectorType type) {
|
||||
auto elementType = convertType(type.getElementType());
|
||||
if (!elementType)
|
||||
return {};
|
||||
if (type.getShape().empty())
|
||||
return VectorType::get({1}, elementType);
|
||||
Type vectorType = VectorType::get(type.getShape().back(), elementType);
|
||||
assert(LLVM::isCompatibleVectorType(vectorType) &&
|
||||
"expected vector type compatible with the LLVM dialect");
|
||||
|
|
|
@ -40,6 +40,7 @@ static Value insertOne(ConversionPatternRewriter &rewriter,
|
|||
LLVMTypeConverter &typeConverter, Location loc,
|
||||
Value val1, Value val2, Type llvmType, int64_t rank,
|
||||
int64_t pos) {
|
||||
assert(rank > 0 && "0-D vector corner case should have been handled already");
|
||||
if (rank == 1) {
|
||||
auto idxType = rewriter.getIndexType();
|
||||
auto constant = rewriter.create<LLVM::ConstantOp>(
|
||||
|
@ -56,6 +57,7 @@ static Value insertOne(ConversionPatternRewriter &rewriter,
|
|||
static Value extractOne(ConversionPatternRewriter &rewriter,
|
||||
LLVMTypeConverter &typeConverter, Location loc,
|
||||
Value val, Type llvmType, int64_t rank, int64_t pos) {
|
||||
assert(rank > 0 && "0-D vector corner case should have been handled already");
|
||||
if (rank == 1) {
|
||||
auto idxType = rewriter.getIndexType();
|
||||
auto constant = rewriter.create<LLVM::ConstantOp>(
|
||||
|
@ -542,6 +544,17 @@ public:
|
|||
if (!llvmType)
|
||||
return failure();
|
||||
|
||||
if (vectorType.getRank() == 0) {
|
||||
Location loc = extractEltOp.getLoc();
|
||||
auto idxType = rewriter.getIndexType();
|
||||
auto zero = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, typeConverter->convertType(idxType),
|
||||
rewriter.getIntegerAttr(idxType, 0));
|
||||
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
|
||||
extractEltOp, llvmType, adaptor.vector(), zero);
|
||||
return success();
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
|
||||
extractEltOp, llvmType, adaptor.vector(), adaptor.position());
|
||||
return success();
|
||||
|
|
|
@ -832,6 +832,12 @@ void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|||
// ExtractElementOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
|
||||
Value source) {
|
||||
result.addOperands({source});
|
||||
result.addTypes(source.getType().cast<VectorType>().getElementType());
|
||||
}
|
||||
|
||||
void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
|
||||
Value source, Value position) {
|
||||
result.addOperands({source, position});
|
||||
|
@ -840,8 +846,15 @@ void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
|
|||
|
||||
static LogicalResult verify(vector::ExtractElementOp op) {
|
||||
VectorType vectorType = op.getVectorType();
|
||||
if (vectorType.getRank() == 0) {
|
||||
if (op.position())
|
||||
return op.emitOpError("expected position to be empty with 0-D vector");
|
||||
return success();
|
||||
}
|
||||
if (vectorType.getRank() != 1)
|
||||
return op.emitOpError("expected 1-D vector");
|
||||
return op.emitOpError("unexpected >1 vector rank");
|
||||
if (!op.position())
|
||||
return op.emitOpError("expected position for 1-D vector");
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -418,6 +418,16 @@ func @shuffle_2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf32> {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @extract_element_0d
|
||||
func @extract_element_0d(%a: vector<f32>) -> f32 {
|
||||
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
|
||||
// CHECK: llvm.extractelement %{{.*}}[%[[C0]] : {{.*}}] : vector<1xf32>
|
||||
%1 = vector.extractelement %a[] : vector<f32>
|
||||
return %1 : f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @extract_element(%arg0: vector<16xf32>) -> f32 {
|
||||
%0 = arith.constant 15 : i32
|
||||
%1 = vector.extractelement %arg0[%0 : i32]: vector<16xf32>
|
||||
|
|
|
@ -72,9 +72,25 @@ func @shuffle_empty_mask(%arg0: vector<2xf32>, %arg1: vector<2xf32>) {
|
|||
|
||||
// -----
|
||||
|
||||
func @extract_element(%arg0: vector<f32>) {
|
||||
%c = arith.constant 3 : i32
|
||||
// expected-error@+1 {{expected position to be empty with 0-D vector}}
|
||||
%1 = vector.extractelement %arg0[%c : i32] : vector<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @extract_element(%arg0: vector<4xf32>) {
|
||||
%c = arith.constant 3 : i32
|
||||
// expected-error@+1 {{expected position for 1-D vector}}
|
||||
%1 = vector.extractelement %arg0[] : vector<4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @extract_element(%arg0: vector<4x4xf32>) {
|
||||
%c = arith.constant 3 : i32
|
||||
// expected-error@+1 {{'vector.extractelement' op expected 1-D vector}}
|
||||
// expected-error@+1 {{unexpected >1 vector rank}}
|
||||
%1 = vector.extractelement %arg0[%c : i32] : vector<4x4xf32>
|
||||
}
|
||||
|
||||
|
|
|
@ -163,6 +163,13 @@ func @shuffle2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf32> {
|
|||
return %1 : vector<3x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @extract_element_0d
|
||||
func @extract_element_0d(%a: vector<f32>) -> f32 {
|
||||
// CHECK-NEXT: vector.extractelement %{{.*}}[] : vector<f32>
|
||||
%1 = vector.extractelement %a[] : vector<f32>
|
||||
return %1 : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @extract_element
|
||||
func @extract_element(%a: vector<16xf32>) -> f32 {
|
||||
// CHECK: %[[C15:.*]] = arith.constant 15 : i32
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \
|
||||
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
|
||||
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
|
||||
// RUN: FileCheck %s
|
||||
|
||||
func @extract_element_0d(%a: vector<f32>) {
|
||||
%1 = vector.extractelement %a[] : vector<f32>
|
||||
// CHECK: 42
|
||||
vector.print %1: f32
|
||||
return
|
||||
}
|
||||
|
||||
func @entry() {
|
||||
%1 = arith.constant dense<42.0> : vector<f32>
|
||||
call @extract_element_0d(%1) : (vector<f32>) -> ()
|
||||
return
|
||||
}
|
Loading…
Reference in New Issue