forked from OSchip/llvm-project
Adding min(f/s/u) and max(f/s/u) cases for vector reduction
This PR adds missing AtomicRMWKind::min/max cases which we would like to use for min/max reduction loop vectorizations. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D104881
This commit is contained in:
parent
53438979fe
commit
89837a0e1b
|
@ -350,9 +350,32 @@ static LogicalResult verify(AtomicRMWOp op) {
|
||||||
Attribute mlir::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
|
Attribute mlir::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
|
||||||
OpBuilder &builder, Location loc) {
|
OpBuilder &builder, Location loc) {
|
||||||
switch (kind) {
|
switch (kind) {
|
||||||
|
case AtomicRMWKind::maxf:
|
||||||
|
return builder.getFloatAttr(
|
||||||
|
resultType,
|
||||||
|
APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
|
||||||
|
/*Negative=*/true));
|
||||||
case AtomicRMWKind::addf:
|
case AtomicRMWKind::addf:
|
||||||
case AtomicRMWKind::addi:
|
case AtomicRMWKind::addi:
|
||||||
|
case AtomicRMWKind::maxu:
|
||||||
return builder.getZeroAttr(resultType);
|
return builder.getZeroAttr(resultType);
|
||||||
|
case AtomicRMWKind::maxs:
|
||||||
|
return builder.getIntegerAttr(
|
||||||
|
resultType,
|
||||||
|
APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
|
||||||
|
case AtomicRMWKind::minf:
|
||||||
|
return builder.getFloatAttr(
|
||||||
|
resultType,
|
||||||
|
APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
|
||||||
|
/*Negative=*/false));
|
||||||
|
case AtomicRMWKind::mins:
|
||||||
|
return builder.getIntegerAttr(
|
||||||
|
resultType,
|
||||||
|
APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
|
||||||
|
case AtomicRMWKind::minu:
|
||||||
|
return builder.getIntegerAttr(
|
||||||
|
resultType,
|
||||||
|
APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
|
||||||
case AtomicRMWKind::muli:
|
case AtomicRMWKind::muli:
|
||||||
return builder.getIntegerAttr(resultType, 1);
|
return builder.getIntegerAttr(resultType, 1);
|
||||||
case AtomicRMWKind::mulf:
|
case AtomicRMWKind::mulf:
|
||||||
|
@ -385,6 +408,30 @@ Value mlir::getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
|
||||||
return builder.create<MulFOp>(loc, lhs, rhs);
|
return builder.create<MulFOp>(loc, lhs, rhs);
|
||||||
case AtomicRMWKind::muli:
|
case AtomicRMWKind::muli:
|
||||||
return builder.create<MulIOp>(loc, lhs, rhs);
|
return builder.create<MulIOp>(loc, lhs, rhs);
|
||||||
|
case AtomicRMWKind::maxf:
|
||||||
|
return builder.create<SelectOp>(
|
||||||
|
loc, builder.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs), lhs,
|
||||||
|
rhs);
|
||||||
|
case AtomicRMWKind::minf:
|
||||||
|
return builder.create<SelectOp>(
|
||||||
|
loc, builder.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs), lhs,
|
||||||
|
rhs);
|
||||||
|
case AtomicRMWKind::maxs:
|
||||||
|
return builder.create<SelectOp>(
|
||||||
|
loc, builder.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs, rhs), lhs,
|
||||||
|
rhs);
|
||||||
|
case AtomicRMWKind::mins:
|
||||||
|
return builder.create<SelectOp>(
|
||||||
|
loc, builder.create<CmpIOp>(loc, CmpIPredicate::slt, lhs, rhs), lhs,
|
||||||
|
rhs);
|
||||||
|
case AtomicRMWKind::maxu:
|
||||||
|
return builder.create<SelectOp>(
|
||||||
|
loc, builder.create<CmpIOp>(loc, CmpIPredicate::ugt, lhs, rhs), lhs,
|
||||||
|
rhs);
|
||||||
|
case AtomicRMWKind::minu:
|
||||||
|
return builder.create<SelectOp>(
|
||||||
|
loc, builder.create<CmpIOp>(loc, CmpIPredicate::ult, lhs, rhs), lhs,
|
||||||
|
rhs);
|
||||||
// TODO: Add remaining reduction operations.
|
// TODO: Add remaining reduction operations.
|
||||||
default:
|
default:
|
||||||
(void)emitOptionalError(loc, "Reduction operation type not supported");
|
(void)emitOptionalError(loc, "Reduction operation type not supported");
|
||||||
|
|
|
@ -357,6 +357,18 @@ Value mlir::vector::getVectorReductionOp(AtomicRMWKind op, OpBuilder &builder,
|
||||||
return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
|
return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
|
||||||
builder.getStringAttr("mul"),
|
builder.getStringAttr("mul"),
|
||||||
vector, ValueRange{});
|
vector, ValueRange{});
|
||||||
|
case AtomicRMWKind::minf:
|
||||||
|
case AtomicRMWKind::mins:
|
||||||
|
case AtomicRMWKind::minu:
|
||||||
|
return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
|
||||||
|
builder.getStringAttr("min"),
|
||||||
|
vector, ValueRange{});
|
||||||
|
case AtomicRMWKind::maxf:
|
||||||
|
case AtomicRMWKind::maxs:
|
||||||
|
case AtomicRMWKind::maxu:
|
||||||
|
return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
|
||||||
|
builder.getStringAttr("max"),
|
||||||
|
vector, ValueRange{});
|
||||||
// TODO: Add remaining reduction operations.
|
// TODO: Add remaining reduction operations.
|
||||||
default:
|
default:
|
||||||
(void)emitOptionalError(loc, "Reduction operation type not supported");
|
(void)emitOptionalError(loc, "Reduction operation type not supported");
|
||||||
|
|
Loading…
Reference in New Issue