[mlir][VectorOps] Loosen restrictions on vector.reduction types

LLVM can deal with any integer or float type, don't arbitrarily restrict
it to f32/f64/i32/i64.

Differential Revision: https://reviews.llvm.org/D88010
This commit is contained in:
Benjamin Kramer 2020-09-21 12:04:33 +02:00
parent f4c5cadbcb
commit 2d76274b99
3 changed files with 25 additions and 6 deletions

View File

@ -561,7 +561,7 @@ public:
auto kind = reductionOp.kind();
Type eltType = reductionOp.dest().getType();
Type llvmType = typeConverter.convertType(eltType);
if (eltType.isSignlessInteger(32) || eltType.isSignlessInteger(64)) {
if (eltType.isSignlessInteger()) {
// Integer reductions: add/mul/min/max/and/or/xor.
if (kind == "add")
rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_add>(
@ -588,7 +588,7 @@ public:
return failure();
return success();
} else if (eltType.isF32() || eltType.isF64()) {
} else if (eltType.isa<FloatType>()) {
// Floating-point reductions: add/mul/min/max
if (kind == "add") {
// Optional accumulator (or zero).

View File

@ -132,11 +132,10 @@ static LogicalResult verify(ReductionOp op) {
auto kind = op.kind();
Type eltType = op.dest().getType();
if (kind == "add" || kind == "mul" || kind == "min" || kind == "max") {
if (!eltType.isF32() && !eltType.isF64() &&
!eltType.isSignlessInteger(32) && !eltType.isSignlessInteger(64))
if (!eltType.isSignlessIntOrFloat())
return op.emitOpError("unsupported reduction type");
} else if (kind == "and" || kind == "or" || kind == "xor") {
if (!eltType.isSignlessInteger(32) && !eltType.isSignlessInteger(64))
if (!eltType.isSignlessInteger())
return op.emitOpError("unsupported reduction type");
} else {
return op.emitOpError("unknown reduction kind: ") << kind;
@ -146,7 +145,7 @@ static LogicalResult verify(ReductionOp op) {
if (!op.acc().empty()) {
if (kind != "add" && kind != "mul")
return op.emitOpError("no accumulator for reduction kind: ") << kind;
if (!eltType.isF32() && !eltType.isF64())
if (!eltType.isa<FloatType>())
return op.emitOpError("no accumulator for type: ") << eltType;
}

View File

@ -678,6 +678,17 @@ func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>) -> (vector<8xf32>, vect
return %0, %1: vector<8xf32>, vector<2x4xf32>
}
func @reduce_f16(%arg0: vector<16xf16>) -> f16 {
%0 = vector.reduction "add", %arg0 : vector<16xf16> into f16
return %0 : f16
}
// CHECK-LABEL: llvm.func @reduce_f16(
// CHECK-SAME: %[[A:.*]]: !llvm.vec<16 x half>)
// CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f16) : !llvm.half
// CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fadd"(%[[C]], %[[A]])
// CHECK-SAME: {reassoc = false} : (!llvm.half, !llvm.vec<16 x half>) -> !llvm.half
// CHECK: llvm.return %[[V]] : !llvm.half
func @reduce_f32(%arg0: vector<16xf32>) -> f32 {
%0 = vector.reduction "add", %arg0 : vector<16xf32> into f32
return %0 : f32
@ -700,6 +711,15 @@ func @reduce_f64(%arg0: vector<16xf64>) -> f64 {
// CHECK-SAME: {reassoc = false} : (!llvm.double, !llvm.vec<16 x double>) -> !llvm.double
// CHECK: llvm.return %[[V]] : !llvm.double
func @reduce_i8(%arg0: vector<16xi8>) -> i8 {
%0 = vector.reduction "add", %arg0 : vector<16xi8> into i8
return %0 : i8
}
// CHECK-LABEL: llvm.func @reduce_i8(
// CHECK-SAME: %[[A:.*]]: !llvm.vec<16 x i8>)
// CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.add"(%[[A]])
// CHECK: llvm.return %[[V]] : !llvm.i8
func @reduce_i32(%arg0: vector<16xi32>) -> i32 {
%0 = vector.reduction "add", %arg0 : vector<16xi32> into i32
return %0 : i32