forked from OSchip/llvm-project
[mlir][vector] Add vector.bitcast operation
Based on the RFC discussed here: https://llvm.discourse.group/t/rfc-vector-standard-add-bitcast-operation/1628/ Adding a vector.bitcast operation that allows casting to a vector of different element type. The most minor dimension bitwidth must stay unchanged. Differential Revision: https://reviews.llvm.org/D86580
This commit is contained in:
parent
5d989fb37d
commit
5fbfe2ec4f
|
@ -1525,6 +1525,41 @@ def Vector_ShapeCastOp :
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Vector_BitCastOp :
|
||||
Vector_Op<"bitcast", [NoSideEffect, AllRanksMatch<["source", "result"]>]>,
|
||||
Arguments<(ins AnyVector:$source)>,
|
||||
Results<(outs AnyVector:$result)>{
|
||||
let summary = "bitcast casts between vectors";
|
||||
let description = [{
|
||||
The bitcast operation casts between vectors of the same rank, the minor 1-D
|
||||
vector size is casted to a vector with a different element type but same
|
||||
bitwidth.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
// Example casting to a smaller element type.
|
||||
%1 = vector.bitcast %0 : vector<5x1x4x3xf32> to vector<5x1x4x6xi16>
|
||||
|
||||
// Example casting to a bigger element type.
|
||||
%3 = vector.bitcast %2 : vector<10x12x8xi8> to vector<10x12x2xi32>
|
||||
|
||||
// Example casting to an element type of the same size.
|
||||
%5 = vector.bitcast %4 : vector<5x1x4x3xf32> to vector<5x1x4x3xi32>
|
||||
```
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
VectorType getSourceVectorType() {
|
||||
return source().getType().cast<VectorType>();
|
||||
}
|
||||
VectorType getResultVectorType() {
|
||||
return getResult().getType().cast<VectorType>();
|
||||
}
|
||||
}];
|
||||
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Vector_TypeCastOp :
|
||||
Vector_Op<"type_cast", [NoSideEffect]>,
|
||||
Arguments<(ins StaticShapeMemRefOf<[AnyType]>:$memref)>,
|
||||
|
|
|
@ -2300,6 +2300,42 @@ OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
|
|||
return {};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// VectorBitCastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(BitCastOp op) {
|
||||
auto sourceVectorType = op.getSourceVectorType();
|
||||
auto resultVectorType = op.getResultVectorType();
|
||||
|
||||
for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
|
||||
if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
|
||||
return op.emitOpError("dimension size mismatch at: ") << i;
|
||||
}
|
||||
|
||||
if (sourceVectorType.getElementTypeBitWidth() *
|
||||
sourceVectorType.getShape().back() !=
|
||||
resultVectorType.getElementTypeBitWidth() *
|
||||
resultVectorType.getShape().back())
|
||||
return op.emitOpError(
|
||||
"source/result bitwidth of the minor 1-D vectors must be equal");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
OpFoldResult BitCastOp::fold(ArrayRef<Attribute> operands) {
|
||||
// Nop cast.
|
||||
if (source().getType() == result().getType())
|
||||
return source();
|
||||
|
||||
// Canceling bitcasts.
|
||||
if (auto otherOp = source().getDefiningOp<BitCastOp>())
|
||||
if (result().getType() == otherOp.source().getType())
|
||||
return otherOp.source();
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TypeCastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -372,3 +372,16 @@ func @fold_vector_transfers(%A: memref<?x8xf32>) -> (vector<4x8xf32>, vector<4x9
|
|||
// CHECK: return
|
||||
return %1, %2 : vector<4x8xf32>, vector<4x9xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: bitcast_folding
|
||||
// CHECK-SAME: %[[A:.*]]: vector<4x8xf32>
|
||||
// CHECK-SAME: %[[B:.*]]: vector<2xi32>
|
||||
// CHECK: return %[[A]], %[[B]] : vector<4x8xf32>, vector<2xi32>
|
||||
func @bitcast_folding(%I1: vector<4x8xf32>, %I2: vector<2xi32>) -> (vector<4x8xf32>, vector<2xi32>) {
|
||||
%0 = vector.bitcast %I1 : vector<4x8xf32> to vector<4x8xf32>
|
||||
%1 = vector.bitcast %I2 : vector<2xi32> to vector<4xi16>
|
||||
%2 = vector.bitcast %1 : vector<4xi16> to vector<2xi32>
|
||||
return %0, %2 : vector<4x8xf32>, vector<2xi32>
|
||||
}
|
||||
|
|
|
@ -1065,6 +1065,34 @@ func @shape_cast_different_tuple_sizes(
|
|||
|
||||
// -----
|
||||
|
||||
func @bitcast_not_vector(%arg0 : vector<5x1x3x2xf32>) {
|
||||
// expected-error@+1 {{must be vector of any type values}}
|
||||
%0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @bitcast_rank_mismatch(%arg0 : vector<5x1x3x2xf32>) {
|
||||
// expected-error@+1 {{op failed to verify that all of {source, result} have same rank}}
|
||||
%0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to vector<5x3x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @bitcast_shape_mismatch(%arg0 : vector<5x1x3x2xf32>) {
|
||||
// expected-error@+1 {{op dimension size mismatch}}
|
||||
%0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to vector<5x2x3x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @bitcast_sizemismatch(%arg0 : vector<5x1x3x2xf32>) {
|
||||
// expected-error@+1 {{op source/result bitwidth of the minor 1-D vectors must be equal}}
|
||||
%0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to vector<5x1x3x3xf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @reduce_unknown_kind(%arg0: vector<16xf32>) -> f32 {
|
||||
// expected-error@+1 {{'vector.reduction' op unknown reduction kind: joho}}
|
||||
%0 = vector.reduction "joho", %arg0 : vector<16xf32> into f32
|
||||
|
|
|
@ -298,6 +298,33 @@ func @shape_cast(%arg0 : vector<5x1x3x2xf32>,
|
|||
return %0, %1, %2, %3, %4 : vector<15x2xf32>, tuple<vector<20x2xf32>, vector<12x2xf32>>, vector<8xf32>, vector<16xf32>, vector<16x1xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @bitcast
|
||||
func @bitcast(%arg0 : vector<5x1x3x2xf32>,
|
||||
%arg1 : vector<8x1xi32>,
|
||||
%arg2 : vector<16x1x8xi8>)
|
||||
-> (vector<5x1x3x4xf16>, vector<5x1x3x8xi8>, vector<8x4xi8>, vector<8x1xf32>, vector<16x1x2xi32>, vector<16x1x4xi16>) {
|
||||
|
||||
// CHECK: vector.bitcast %{{.*}} : vector<5x1x3x2xf32> to vector<5x1x3x4xf16>
|
||||
%0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to vector<5x1x3x4xf16>
|
||||
|
||||
// CHECK-NEXT: vector.bitcast %{{.*}} : vector<5x1x3x2xf32> to vector<5x1x3x8xi8>
|
||||
%1 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to vector<5x1x3x8xi8>
|
||||
|
||||
// CHECK-NEXT: vector.bitcast %{{.*}} : vector<8x1xi32> to vector<8x4xi8>
|
||||
%2 = vector.bitcast %arg1 : vector<8x1xi32> to vector<8x4xi8>
|
||||
|
||||
// CHECK-NEXT: vector.bitcast %{{.*}} : vector<8x1xi32> to vector<8x1xf32>
|
||||
%3 = vector.bitcast %arg1 : vector<8x1xi32> to vector<8x1xf32>
|
||||
|
||||
// CHECK-NEXT: vector.bitcast %{{.*}} : vector<16x1x8xi8> to vector<16x1x2xi32>
|
||||
%4 = vector.bitcast %arg2 : vector<16x1x8xi8> to vector<16x1x2xi32>
|
||||
|
||||
// CHECK-NEXT: vector.bitcast %{{.*}} : vector<16x1x8xi8> to vector<16x1x4xi16>
|
||||
%5 = vector.bitcast %arg2 : vector<16x1x8xi8> to vector<16x1x4xi16>
|
||||
|
||||
return %0, %1, %2, %3, %4, %5 : vector<5x1x3x4xf16>, vector<5x1x3x8xi8>, vector<8x4xi8>, vector<8x1xf32>, vector<16x1x2xi32>, vector<16x1x4xi16>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @vector_fma
|
||||
func @vector_fma(%a: vector<8xf32>, %b: vector<8x4xf32>) {
|
||||
// CHECK: vector.fma %{{.*}} : vector<8xf32>
|
||||
|
|
Loading…
Reference in New Issue