forked from OSchip/llvm-project
Adding min(f/s/u) and max(f/s/u) cases for vector reduction
This PR adds missing AtomicRMWKind::min/max cases which we would like to use for min/max reduction loop vectorizations. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D104881
This commit is contained in:
parent
53438979fe
commit
89837a0e1b
|
@ -350,9 +350,32 @@ static LogicalResult verify(AtomicRMWOp op) {
|
|||
Attribute mlir::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
|
||||
OpBuilder &builder, Location loc) {
|
||||
switch (kind) {
|
||||
case AtomicRMWKind::maxf:
|
||||
return builder.getFloatAttr(
|
||||
resultType,
|
||||
APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
|
||||
/*Negative=*/true));
|
||||
case AtomicRMWKind::addf:
|
||||
case AtomicRMWKind::addi:
|
||||
case AtomicRMWKind::maxu:
|
||||
return builder.getZeroAttr(resultType);
|
||||
case AtomicRMWKind::maxs:
|
||||
return builder.getIntegerAttr(
|
||||
resultType,
|
||||
APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
|
||||
case AtomicRMWKind::minf:
|
||||
return builder.getFloatAttr(
|
||||
resultType,
|
||||
APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
|
||||
/*Negative=*/false));
|
||||
case AtomicRMWKind::mins:
|
||||
return builder.getIntegerAttr(
|
||||
resultType,
|
||||
APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
|
||||
case AtomicRMWKind::minu:
|
||||
return builder.getIntegerAttr(
|
||||
resultType,
|
||||
APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
|
||||
case AtomicRMWKind::muli:
|
||||
return builder.getIntegerAttr(resultType, 1);
|
||||
case AtomicRMWKind::mulf:
|
||||
|
@ -385,6 +408,30 @@ Value mlir::getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
|
|||
return builder.create<MulFOp>(loc, lhs, rhs);
|
||||
case AtomicRMWKind::muli:
|
||||
return builder.create<MulIOp>(loc, lhs, rhs);
|
||||
case AtomicRMWKind::maxf:
|
||||
return builder.create<SelectOp>(
|
||||
loc, builder.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs), lhs,
|
||||
rhs);
|
||||
case AtomicRMWKind::minf:
|
||||
return builder.create<SelectOp>(
|
||||
loc, builder.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs), lhs,
|
||||
rhs);
|
||||
case AtomicRMWKind::maxs:
|
||||
return builder.create<SelectOp>(
|
||||
loc, builder.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs, rhs), lhs,
|
||||
rhs);
|
||||
case AtomicRMWKind::mins:
|
||||
return builder.create<SelectOp>(
|
||||
loc, builder.create<CmpIOp>(loc, CmpIPredicate::slt, lhs, rhs), lhs,
|
||||
rhs);
|
||||
case AtomicRMWKind::maxu:
|
||||
return builder.create<SelectOp>(
|
||||
loc, builder.create<CmpIOp>(loc, CmpIPredicate::ugt, lhs, rhs), lhs,
|
||||
rhs);
|
||||
case AtomicRMWKind::minu:
|
||||
return builder.create<SelectOp>(
|
||||
loc, builder.create<CmpIOp>(loc, CmpIPredicate::ult, lhs, rhs), lhs,
|
||||
rhs);
|
||||
// TODO: Add remaining reduction operations.
|
||||
default:
|
||||
(void)emitOptionalError(loc, "Reduction operation type not supported");
|
||||
|
|
|
@ -357,6 +357,18 @@ Value mlir::vector::getVectorReductionOp(AtomicRMWKind op, OpBuilder &builder,
|
|||
return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
|
||||
builder.getStringAttr("mul"),
|
||||
vector, ValueRange{});
|
||||
case AtomicRMWKind::minf:
|
||||
case AtomicRMWKind::mins:
|
||||
case AtomicRMWKind::minu:
|
||||
return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
|
||||
builder.getStringAttr("min"),
|
||||
vector, ValueRange{});
|
||||
case AtomicRMWKind::maxf:
|
||||
case AtomicRMWKind::maxs:
|
||||
case AtomicRMWKind::maxu:
|
||||
return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
|
||||
builder.getStringAttr("max"),
|
||||
vector, ValueRange{});
|
||||
// TODO: Add remaining reduction operations.
|
||||
default:
|
||||
(void)emitOptionalError(loc, "Reduction operation type not supported");
|
||||
|
|
Loading…
Reference in New Issue