[mlir][vector] Fix crash in vector.reduction canonicalization

since vector.reduce support accumulator in all the cases remove the
assert assuming old definition.

Differential Revision: https://reviews.llvm.org/D129602
This commit is contained in:
Thomas Raoux 2022-07-12 22:44:39 +00:00
parent cc7d966511
commit 5f8cefebd9
5 changed files with 70 additions and 68 deletions

View File

@ -182,6 +182,11 @@ bool isDisjointTransferIndices(VectorTransferOpInterface transferA,
/// memory.
bool isDisjointTransferSet(VectorTransferOpInterface transferA,
VectorTransferOpInterface transferB);
/// Return the result value of reducing two scalar/vector values with the
/// corresponding arith operation.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
Value v1, Value v2);
} // namespace vector
} // namespace mlir

View File

@ -34,11 +34,6 @@ namespace vector {
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
/// Return the result value of reducing two scalar/vector values with the
/// corresponding arith operation.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
Value v1, Value v2);
} // namespace vector
/// Return the number of elements of basis, `0` if empty.

View File

@ -501,19 +501,9 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
reductionOp.getVector(),
rewriter.getI64ArrayAttr(0));
if (Value acc = reductionOp.getAcc()) {
assert(reductionOp.getType().isa<FloatType>());
switch (reductionOp.getKind()) {
case CombiningKind::ADD:
result = rewriter.create<arith::AddFOp>(loc, result, acc);
break;
case CombiningKind::MUL:
result = rewriter.create<arith::MulFOp>(loc, result, acc);
break;
default:
assert(false && "invalid op!");
}
}
if (Value acc = reductionOp.getAcc())
result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
result, acc);
rewriter.replaceOp(reductionOp, result);
return success();
@ -5007,6 +4997,56 @@ bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
}
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
CombiningKind kind, Value v1, Value v2) {
Type t1 = getElementTypeOrSelf(v1.getType());
Type t2 = getElementTypeOrSelf(v2.getType());
switch (kind) {
case CombiningKind::ADD:
if (t1.isIntOrIndex() && t2.isIntOrIndex())
return b.createOrFold<arith::AddIOp>(loc, v1, v2);
else if (t1.isa<FloatType>() && t2.isa<FloatType>())
return b.createOrFold<arith::AddFOp>(loc, v1, v2);
llvm_unreachable("invalid value types for ADD reduction");
case CombiningKind::AND:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::AndIOp>(loc, v1, v2);
case CombiningKind::MAXF:
assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
"expected float values");
return b.createOrFold<arith::MaxFOp>(loc, v1, v2);
case CombiningKind::MINF:
assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
"expected float values");
return b.createOrFold<arith::MinFOp>(loc, v1, v2);
case CombiningKind::MAXSI:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::MaxSIOp>(loc, v1, v2);
case CombiningKind::MINSI:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::MinSIOp>(loc, v1, v2);
case CombiningKind::MAXUI:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::MaxUIOp>(loc, v1, v2);
case CombiningKind::MINUI:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::MinUIOp>(loc, v1, v2);
case CombiningKind::MUL:
if (t1.isIntOrIndex() && t2.isIntOrIndex())
return b.createOrFold<arith::MulIOp>(loc, v1, v2);
else if (t1.isa<FloatType>() && t2.isa<FloatType>())
return b.createOrFold<arith::MulFOp>(loc, v1, v2);
llvm_unreachable("invalid value types for MUL reduction");
case CombiningKind::OR:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::OrIOp>(loc, v1, v2);
case CombiningKind::XOR:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::XOrIOp>(loc, v1, v2);
};
llvm_unreachable("unknown CombiningKind");
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

View File

@ -43,56 +43,6 @@ Value mlir::vector::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
llvm_unreachable("Expected MemRefType or TensorType");
}
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
CombiningKind kind, Value v1, Value v2) {
Type t1 = getElementTypeOrSelf(v1.getType());
Type t2 = getElementTypeOrSelf(v2.getType());
switch (kind) {
case CombiningKind::ADD:
if (t1.isIntOrIndex() && t2.isIntOrIndex())
return b.createOrFold<arith::AddIOp>(loc, v1, v2);
else if (t1.isa<FloatType>() && t2.isa<FloatType>())
return b.createOrFold<arith::AddFOp>(loc, v1, v2);
llvm_unreachable("invalid value types for ADD reduction");
case CombiningKind::AND:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::AndIOp>(loc, v1, v2);
case CombiningKind::MAXF:
assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
"expected float values");
return b.createOrFold<arith::MaxFOp>(loc, v1, v2);
case CombiningKind::MINF:
assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
"expected float values");
return b.createOrFold<arith::MinFOp>(loc, v1, v2);
case CombiningKind::MAXSI:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::MaxSIOp>(loc, v1, v2);
case CombiningKind::MINSI:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::MinSIOp>(loc, v1, v2);
case CombiningKind::MAXUI:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::MaxUIOp>(loc, v1, v2);
case CombiningKind::MINUI:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::MinUIOp>(loc, v1, v2);
case CombiningKind::MUL:
if (t1.isIntOrIndex() && t2.isIntOrIndex())
return b.createOrFold<arith::MulIOp>(loc, v1, v2);
else if (t1.isa<FloatType>() && t2.isa<FloatType>())
return b.createOrFold<arith::MulFOp>(loc, v1, v2);
llvm_unreachable("invalid value types for MUL reduction");
case CombiningKind::OR:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::OrIOp>(loc, v1, v2);
case CombiningKind::XOR:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::XOrIOp>(loc, v1, v2);
};
llvm_unreachable("unknown CombiningKind");
}
/// Return the number of elements of basis, `0` if empty.
int64_t mlir::computeMaxLinearIndex(ArrayRef<int64_t> basis) {
if (basis.empty())

View File

@ -1619,6 +1619,18 @@ func.func @dont_reduce_one_element_vector(%a : vector<4xf32>) -> f32 {
// -----
// CHECK-LABEL: func @reduce_one_element_vector_maxf
// CHECK-SAME: (%[[V:.+]]: vector<1xf32>, %[[B:.+]]: f32)
// CHECK: %[[A:.+]] = vector.extract %[[V]][0] : vector<1xf32>
// CHECK: %[[S:.+]] = arith.maxf %[[A]], %[[B]] : f32
// CHECK: return %[[S]]
func.func @reduce_one_element_vector_maxf(%a : vector<1xf32>, %b: f32) -> f32 {
%s = vector.reduction <maxf>, %a, %b : vector<1xf32> into f32
return %s : f32
}
// -----
// CHECK-LABEL: func @bitcast(
// CHECK-SAME: %[[ARG:.*]]: vector<4x8xf32>) -> vector<4x16xi16> {
// CHECK: vector.bitcast %[[ARG:.*]] : vector<4x8xf32> to vector<4x16xi16>