forked from OSchip/llvm-project
[mlir][vector] Fix crash in vector.reduction canonicalization
since vector.reduce support accumulator in all the cases remove the assert assuming old definition. Differential Revision: https://reviews.llvm.org/D129602
This commit is contained in:
parent
cc7d966511
commit
5f8cefebd9
|
@ -182,6 +182,11 @@ bool isDisjointTransferIndices(VectorTransferOpInterface transferA,
|
|||
/// memory.
|
||||
bool isDisjointTransferSet(VectorTransferOpInterface transferA,
|
||||
VectorTransferOpInterface transferB);
|
||||
|
||||
/// Return the result value of reducing two scalar/vector values with the
|
||||
/// corresponding arith operation.
|
||||
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
|
||||
Value v1, Value v2);
|
||||
} // namespace vector
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -34,11 +34,6 @@ namespace vector {
|
|||
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
|
||||
/// the type of `source`.
|
||||
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
|
||||
|
||||
/// Return the result value of reducing two scalar/vector values with the
|
||||
/// corresponding arith operation.
|
||||
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
|
||||
Value v1, Value v2);
|
||||
} // namespace vector
|
||||
|
||||
/// Return the number of elements of basis, `0` if empty.
|
||||
|
|
|
@ -501,19 +501,9 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
|
|||
reductionOp.getVector(),
|
||||
rewriter.getI64ArrayAttr(0));
|
||||
|
||||
if (Value acc = reductionOp.getAcc()) {
|
||||
assert(reductionOp.getType().isa<FloatType>());
|
||||
switch (reductionOp.getKind()) {
|
||||
case CombiningKind::ADD:
|
||||
result = rewriter.create<arith::AddFOp>(loc, result, acc);
|
||||
break;
|
||||
case CombiningKind::MUL:
|
||||
result = rewriter.create<arith::MulFOp>(loc, result, acc);
|
||||
break;
|
||||
default:
|
||||
assert(false && "invalid op!");
|
||||
}
|
||||
}
|
||||
if (Value acc = reductionOp.getAcc())
|
||||
result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
|
||||
result, acc);
|
||||
|
||||
rewriter.replaceOp(reductionOp, result);
|
||||
return success();
|
||||
|
@ -5007,6 +4997,56 @@ bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
|
|||
verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
|
||||
}
|
||||
|
||||
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
|
||||
CombiningKind kind, Value v1, Value v2) {
|
||||
Type t1 = getElementTypeOrSelf(v1.getType());
|
||||
Type t2 = getElementTypeOrSelf(v2.getType());
|
||||
switch (kind) {
|
||||
case CombiningKind::ADD:
|
||||
if (t1.isIntOrIndex() && t2.isIntOrIndex())
|
||||
return b.createOrFold<arith::AddIOp>(loc, v1, v2);
|
||||
else if (t1.isa<FloatType>() && t2.isa<FloatType>())
|
||||
return b.createOrFold<arith::AddFOp>(loc, v1, v2);
|
||||
llvm_unreachable("invalid value types for ADD reduction");
|
||||
case CombiningKind::AND:
|
||||
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
|
||||
return b.createOrFold<arith::AndIOp>(loc, v1, v2);
|
||||
case CombiningKind::MAXF:
|
||||
assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
|
||||
"expected float values");
|
||||
return b.createOrFold<arith::MaxFOp>(loc, v1, v2);
|
||||
case CombiningKind::MINF:
|
||||
assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
|
||||
"expected float values");
|
||||
return b.createOrFold<arith::MinFOp>(loc, v1, v2);
|
||||
case CombiningKind::MAXSI:
|
||||
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
|
||||
return b.createOrFold<arith::MaxSIOp>(loc, v1, v2);
|
||||
case CombiningKind::MINSI:
|
||||
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
|
||||
return b.createOrFold<arith::MinSIOp>(loc, v1, v2);
|
||||
case CombiningKind::MAXUI:
|
||||
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
|
||||
return b.createOrFold<arith::MaxUIOp>(loc, v1, v2);
|
||||
case CombiningKind::MINUI:
|
||||
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
|
||||
return b.createOrFold<arith::MinUIOp>(loc, v1, v2);
|
||||
case CombiningKind::MUL:
|
||||
if (t1.isIntOrIndex() && t2.isIntOrIndex())
|
||||
return b.createOrFold<arith::MulIOp>(loc, v1, v2);
|
||||
else if (t1.isa<FloatType>() && t2.isa<FloatType>())
|
||||
return b.createOrFold<arith::MulFOp>(loc, v1, v2);
|
||||
llvm_unreachable("invalid value types for MUL reduction");
|
||||
case CombiningKind::OR:
|
||||
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
|
||||
return b.createOrFold<arith::OrIOp>(loc, v1, v2);
|
||||
case CombiningKind::XOR:
|
||||
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
|
||||
return b.createOrFold<arith::XOrIOp>(loc, v1, v2);
|
||||
};
|
||||
llvm_unreachable("unknown CombiningKind");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -43,56 +43,6 @@ Value mlir::vector::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
|
|||
llvm_unreachable("Expected MemRefType or TensorType");
|
||||
}
|
||||
|
||||
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
|
||||
CombiningKind kind, Value v1, Value v2) {
|
||||
Type t1 = getElementTypeOrSelf(v1.getType());
|
||||
Type t2 = getElementTypeOrSelf(v2.getType());
|
||||
switch (kind) {
|
||||
case CombiningKind::ADD:
|
||||
if (t1.isIntOrIndex() && t2.isIntOrIndex())
|
||||
return b.createOrFold<arith::AddIOp>(loc, v1, v2);
|
||||
else if (t1.isa<FloatType>() && t2.isa<FloatType>())
|
||||
return b.createOrFold<arith::AddFOp>(loc, v1, v2);
|
||||
llvm_unreachable("invalid value types for ADD reduction");
|
||||
case CombiningKind::AND:
|
||||
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
|
||||
return b.createOrFold<arith::AndIOp>(loc, v1, v2);
|
||||
case CombiningKind::MAXF:
|
||||
assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
|
||||
"expected float values");
|
||||
return b.createOrFold<arith::MaxFOp>(loc, v1, v2);
|
||||
case CombiningKind::MINF:
|
||||
assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
|
||||
"expected float values");
|
||||
return b.createOrFold<arith::MinFOp>(loc, v1, v2);
|
||||
case CombiningKind::MAXSI:
|
||||
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
|
||||
return b.createOrFold<arith::MaxSIOp>(loc, v1, v2);
|
||||
case CombiningKind::MINSI:
|
||||
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
|
||||
return b.createOrFold<arith::MinSIOp>(loc, v1, v2);
|
||||
case CombiningKind::MAXUI:
|
||||
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
|
||||
return b.createOrFold<arith::MaxUIOp>(loc, v1, v2);
|
||||
case CombiningKind::MINUI:
|
||||
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
|
||||
return b.createOrFold<arith::MinUIOp>(loc, v1, v2);
|
||||
case CombiningKind::MUL:
|
||||
if (t1.isIntOrIndex() && t2.isIntOrIndex())
|
||||
return b.createOrFold<arith::MulIOp>(loc, v1, v2);
|
||||
else if (t1.isa<FloatType>() && t2.isa<FloatType>())
|
||||
return b.createOrFold<arith::MulFOp>(loc, v1, v2);
|
||||
llvm_unreachable("invalid value types for MUL reduction");
|
||||
case CombiningKind::OR:
|
||||
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
|
||||
return b.createOrFold<arith::OrIOp>(loc, v1, v2);
|
||||
case CombiningKind::XOR:
|
||||
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
|
||||
return b.createOrFold<arith::XOrIOp>(loc, v1, v2);
|
||||
};
|
||||
llvm_unreachable("unknown CombiningKind");
|
||||
}
|
||||
|
||||
/// Return the number of elements of basis, `0` if empty.
|
||||
int64_t mlir::computeMaxLinearIndex(ArrayRef<int64_t> basis) {
|
||||
if (basis.empty())
|
||||
|
|
|
@ -1619,6 +1619,18 @@ func.func @dont_reduce_one_element_vector(%a : vector<4xf32>) -> f32 {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @reduce_one_element_vector_maxf
|
||||
// CHECK-SAME: (%[[V:.+]]: vector<1xf32>, %[[B:.+]]: f32)
|
||||
// CHECK: %[[A:.+]] = vector.extract %[[V]][0] : vector<1xf32>
|
||||
// CHECK: %[[S:.+]] = arith.maxf %[[A]], %[[B]] : f32
|
||||
// CHECK: return %[[S]]
|
||||
func.func @reduce_one_element_vector_maxf(%a : vector<1xf32>, %b: f32) -> f32 {
|
||||
%s = vector.reduction <maxf>, %a, %b : vector<1xf32> into f32
|
||||
return %s : f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @bitcast(
|
||||
// CHECK-SAME: %[[ARG:.*]]: vector<4x8xf32>) -> vector<4x16xi16> {
|
||||
// CHECK: vector.bitcast %[[ARG:.*]] : vector<4x8xf32> to vector<4x16xi16>
|
||||
|
|
Loading…
Reference in New Issue