Fix linalg.dot over boolean tensors.

dot is currently miscompiled for booleans (uses add instead of or).

Reviewed By: bkramer

Differential Revision: https://reviews.llvm.org/D129292
This commit is contained in:
Johannes Reifferscheid 2022-07-07 20:36:41 +02:00
parent c5be6a8308
commit ad3a078745
3 changed files with 37 additions and 1 deletions

View File

@ -129,7 +129,9 @@ static MatchContractionResult isContractionInterfaceImpl(Operation *op) {
// TODO: more fields than add/mul.
if (!isAddMul<arith::AddFOp, arith::MulFOp>(linalgOp->getRegion(0).front()) &&
!isAddMul<arith::AddIOp, arith::MulIOp>(linalgOp->getRegion(0).front()) &&
!isAddMul<complex::AddOp, complex::MulOp>(linalgOp->getRegion(0).front()))
!isAddMul<complex::AddOp, complex::MulOp>(
linalgOp->getRegion(0).front()) &&
!isAddMul<arith::OrIOp, arith::AndIOp>(linalgOp->getRegion(0).front()))
return MatchContractionResult::NotAddMul;
return MatchContractionResult::Success;
}

View File

@ -325,6 +325,8 @@ public:
bool allComplex = isComplex(arg0) && isComplex(arg1);
bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
bool allInteger = isInteger(arg0) && isInteger(arg1);
bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
arg1.getType().getIntOrFloatBitWidth() == 1;
if (!allComplex && !allFloatingPoint && !allInteger)
llvm_unreachable("unsupported non numeric type");
OpBuilder builder = getBuilder();
@ -334,18 +336,24 @@ public:
return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
if (allFloatingPoint)
return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
if (allBool)
return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
case BinaryFn::sub:
if (allComplex)
return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
if (allFloatingPoint)
return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
if (allBool)
llvm_unreachable("unsupported operation: sub with bools");
return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
case BinaryFn::mul:
if (allComplex)
return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
if (allFloatingPoint)
return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
if (allBool)
return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
case BinaryFn::max_signed:
assert(!allComplex);

View File

@ -137,6 +137,32 @@ func.func @dot(%arg0: memref<?xi8>, %M: index) {
// CHECKPARALLEL: store %[[res]], %[[C]][] : memref<f32>
func.func @dot_int(%arg0: memref<?xi32>, %arg1: memref<?xi32>,
%arg3: memref<i32>) {
// Verifies that we use the correct arith operations for integers.
linalg.dot ins(%arg0, %arg1 : memref<?xi32>, memref<?xi32>)
outs(%arg3 : memref<i32>)
return
}
// CHECK-LABEL: func @dot_int(
// CHECK: %[[inc:.*]] = arith.muli {{.*}} : i32
// CHECK-NEXT: %[[res:.*]] = arith.addi {{.*}}, %[[inc]] : i32
// CHECK-NEXT: store %[[res]], {{.*}} : memref<i32>
func.func @dot_bool(%arg0: memref<?xi1>, %arg1: memref<?xi1>,
%arg3: memref<i1>) {
// Verifies that we use the correct (saturating) arith operations for booleans.
linalg.dot ins(%arg0, %arg1 : memref<?xi1>, memref<?xi1>)
outs(%arg3 : memref<i1>)
return
}
// CHECK-LABEL: func @dot_bool(
// CHECK: %[[inc:.*]] = arith.andi {{.*}} : i1
// CHECK-NEXT: %[[res:.*]] = arith.ori {{.*}}, %[[inc]] : i1
// CHECK-NEXT: store %[[res]], {{.*}} : memref<i1>
func.func @dot_view(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: memref<?xf32, offset: ?, strides: [1]>, %arg2: memref<f32>) {
linalg.dot ins(%arg0, %arg1 : memref<?xf32, offset: ?, strides: [1]>,
memref<?xf32, offset: ?, strides: [1]>)