Adding min(f/s/u) and max(f/s/u) cases for vector reduction

This PR adds missing AtomicRMWKind::min/max cases which we would like to use for min/max reduction loop vectorizations.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D104881
This commit is contained in:
Alexander Slepko 2021-09-09 11:48:22 -07:00
parent 53438979fe
commit 89837a0e1b
2 changed files with 59 additions and 0 deletions

View File

@ -350,9 +350,32 @@ static LogicalResult verify(AtomicRMWOp op) {
Attribute mlir::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
OpBuilder &builder, Location loc) {
switch (kind) {
case AtomicRMWKind::maxf:
return builder.getFloatAttr(
resultType,
APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
/*Negative=*/true));
case AtomicRMWKind::addf:
case AtomicRMWKind::addi:
case AtomicRMWKind::maxu:
return builder.getZeroAttr(resultType);
case AtomicRMWKind::maxs:
return builder.getIntegerAttr(
resultType,
APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
case AtomicRMWKind::minf:
return builder.getFloatAttr(
resultType,
APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
/*Negative=*/false));
case AtomicRMWKind::mins:
return builder.getIntegerAttr(
resultType,
APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
case AtomicRMWKind::minu:
return builder.getIntegerAttr(
resultType,
APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
case AtomicRMWKind::muli:
return builder.getIntegerAttr(resultType, 1);
case AtomicRMWKind::mulf:
@ -385,6 +408,30 @@ Value mlir::getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
return builder.create<MulFOp>(loc, lhs, rhs);
case AtomicRMWKind::muli:
return builder.create<MulIOp>(loc, lhs, rhs);
case AtomicRMWKind::maxf:
return builder.create<SelectOp>(
loc, builder.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs), lhs,
rhs);
case AtomicRMWKind::minf:
return builder.create<SelectOp>(
loc, builder.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs), lhs,
rhs);
case AtomicRMWKind::maxs:
return builder.create<SelectOp>(
loc, builder.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs, rhs), lhs,
rhs);
case AtomicRMWKind::mins:
return builder.create<SelectOp>(
loc, builder.create<CmpIOp>(loc, CmpIPredicate::slt, lhs, rhs), lhs,
rhs);
case AtomicRMWKind::maxu:
return builder.create<SelectOp>(
loc, builder.create<CmpIOp>(loc, CmpIPredicate::ugt, lhs, rhs), lhs,
rhs);
case AtomicRMWKind::minu:
return builder.create<SelectOp>(
loc, builder.create<CmpIOp>(loc, CmpIPredicate::ult, lhs, rhs), lhs,
rhs);
// TODO: Add remaining reduction operations.
default:
(void)emitOptionalError(loc, "Reduction operation type not supported");

View File

@ -357,6 +357,18 @@ Value mlir::vector::getVectorReductionOp(AtomicRMWKind op, OpBuilder &builder,
return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
builder.getStringAttr("mul"),
vector, ValueRange{});
case AtomicRMWKind::minf:
case AtomicRMWKind::mins:
case AtomicRMWKind::minu:
return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
builder.getStringAttr("min"),
vector, ValueRange{});
case AtomicRMWKind::maxf:
case AtomicRMWKind::maxs:
case AtomicRMWKind::maxu:
return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
builder.getStringAttr("max"),
vector, ValueRange{});
// TODO: Add remaining reduction operations.
default:
(void)emitOptionalError(loc, "Reduction operation type not supported");