forked from OSchip/llvm-project
[mlir][VectorOps] Loosen restrictions on vector.reduction types
LLVM can deal with any integer or float type, don't arbitrarily restrict it to f32/f64/i32/i64. Differential Revision: https://reviews.llvm.org/D88010
This commit is contained in:
parent
f4c5cadbcb
commit
2d76274b99
|
@ -561,7 +561,7 @@ public:
|
|||
auto kind = reductionOp.kind();
|
||||
Type eltType = reductionOp.dest().getType();
|
||||
Type llvmType = typeConverter.convertType(eltType);
|
||||
if (eltType.isSignlessInteger(32) || eltType.isSignlessInteger(64)) {
|
||||
if (eltType.isSignlessInteger()) {
|
||||
// Integer reductions: add/mul/min/max/and/or/xor.
|
||||
if (kind == "add")
|
||||
rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_add>(
|
||||
|
@ -588,7 +588,7 @@ public:
|
|||
return failure();
|
||||
return success();
|
||||
|
||||
} else if (eltType.isF32() || eltType.isF64()) {
|
||||
} else if (eltType.isa<FloatType>()) {
|
||||
// Floating-point reductions: add/mul/min/max
|
||||
if (kind == "add") {
|
||||
// Optional accumulator (or zero).
|
||||
|
|
|
@ -132,11 +132,10 @@ static LogicalResult verify(ReductionOp op) {
|
|||
auto kind = op.kind();
|
||||
Type eltType = op.dest().getType();
|
||||
if (kind == "add" || kind == "mul" || kind == "min" || kind == "max") {
|
||||
if (!eltType.isF32() && !eltType.isF64() &&
|
||||
!eltType.isSignlessInteger(32) && !eltType.isSignlessInteger(64))
|
||||
if (!eltType.isSignlessIntOrFloat())
|
||||
return op.emitOpError("unsupported reduction type");
|
||||
} else if (kind == "and" || kind == "or" || kind == "xor") {
|
||||
if (!eltType.isSignlessInteger(32) && !eltType.isSignlessInteger(64))
|
||||
if (!eltType.isSignlessInteger())
|
||||
return op.emitOpError("unsupported reduction type");
|
||||
} else {
|
||||
return op.emitOpError("unknown reduction kind: ") << kind;
|
||||
|
@ -146,7 +145,7 @@ static LogicalResult verify(ReductionOp op) {
|
|||
if (!op.acc().empty()) {
|
||||
if (kind != "add" && kind != "mul")
|
||||
return op.emitOpError("no accumulator for reduction kind: ") << kind;
|
||||
if (!eltType.isF32() && !eltType.isF64())
|
||||
if (!eltType.isa<FloatType>())
|
||||
return op.emitOpError("no accumulator for type: ") << eltType;
|
||||
}
|
||||
|
||||
|
|
|
@ -678,6 +678,17 @@ func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>) -> (vector<8xf32>, vect
|
|||
return %0, %1: vector<8xf32>, vector<2x4xf32>
|
||||
}
|
||||
|
||||
func @reduce_f16(%arg0: vector<16xf16>) -> f16 {
|
||||
%0 = vector.reduction "add", %arg0 : vector<16xf16> into f16
|
||||
return %0 : f16
|
||||
}
|
||||
// CHECK-LABEL: llvm.func @reduce_f16(
|
||||
// CHECK-SAME: %[[A:.*]]: !llvm.vec<16 x half>)
|
||||
// CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f16) : !llvm.half
|
||||
// CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fadd"(%[[C]], %[[A]])
|
||||
// CHECK-SAME: {reassoc = false} : (!llvm.half, !llvm.vec<16 x half>) -> !llvm.half
|
||||
// CHECK: llvm.return %[[V]] : !llvm.half
|
||||
|
||||
func @reduce_f32(%arg0: vector<16xf32>) -> f32 {
|
||||
%0 = vector.reduction "add", %arg0 : vector<16xf32> into f32
|
||||
return %0 : f32
|
||||
|
@ -700,6 +711,15 @@ func @reduce_f64(%arg0: vector<16xf64>) -> f64 {
|
|||
// CHECK-SAME: {reassoc = false} : (!llvm.double, !llvm.vec<16 x double>) -> !llvm.double
|
||||
// CHECK: llvm.return %[[V]] : !llvm.double
|
||||
|
||||
func @reduce_i8(%arg0: vector<16xi8>) -> i8 {
|
||||
%0 = vector.reduction "add", %arg0 : vector<16xi8> into i8
|
||||
return %0 : i8
|
||||
}
|
||||
// CHECK-LABEL: llvm.func @reduce_i8(
|
||||
// CHECK-SAME: %[[A:.*]]: !llvm.vec<16 x i8>)
|
||||
// CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.add"(%[[A]])
|
||||
// CHECK: llvm.return %[[V]] : !llvm.i8
|
||||
|
||||
func @reduce_i32(%arg0: vector<16xi32>) -> i32 {
|
||||
%0 = vector.reduction "add", %arg0 : vector<16xi32> into i32
|
||||
return %0 : i32
|
||||
|
|
Loading…
Reference in New Issue