forked from OSchip/llvm-project
Fix linalg.dot over boolean tensors.
linalg.dot is currently miscompiled for boolean (i1) operands: the accumulation uses an integer add instead of a logical or. Reviewed By: bkramer Differential Revision: https://reviews.llvm.org/D129292
This commit is contained in:
parent
c5be6a8308
commit
ad3a078745
|
@ -129,7 +129,9 @@ static MatchContractionResult isContractionInterfaceImpl(Operation *op) {
|
|||
// TODO: more fields than add/mul.
|
||||
if (!isAddMul<arith::AddFOp, arith::MulFOp>(linalgOp->getRegion(0).front()) &&
|
||||
!isAddMul<arith::AddIOp, arith::MulIOp>(linalgOp->getRegion(0).front()) &&
|
||||
!isAddMul<complex::AddOp, complex::MulOp>(linalgOp->getRegion(0).front()))
|
||||
!isAddMul<complex::AddOp, complex::MulOp>(
|
||||
linalgOp->getRegion(0).front()) &&
|
||||
!isAddMul<arith::OrIOp, arith::AndIOp>(linalgOp->getRegion(0).front()))
|
||||
return MatchContractionResult::NotAddMul;
|
||||
return MatchContractionResult::Success;
|
||||
}
|
||||
|
|
|
@ -325,6 +325,8 @@ public:
|
|||
bool allComplex = isComplex(arg0) && isComplex(arg1);
|
||||
bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
|
||||
bool allInteger = isInteger(arg0) && isInteger(arg1);
|
||||
bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
|
||||
arg1.getType().getIntOrFloatBitWidth() == 1;
|
||||
if (!allComplex && !allFloatingPoint && !allInteger)
|
||||
llvm_unreachable("unsupported non numeric type");
|
||||
OpBuilder builder = getBuilder();
|
||||
|
@ -334,18 +336,24 @@ public:
|
|||
return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
|
||||
if (allFloatingPoint)
|
||||
return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
|
||||
if (allBool)
|
||||
return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
|
||||
return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
|
||||
case BinaryFn::sub:
|
||||
if (allComplex)
|
||||
return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
|
||||
if (allFloatingPoint)
|
||||
return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
|
||||
if (allBool)
|
||||
llvm_unreachable("unsupported operation: sub with bools");
|
||||
return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
|
||||
case BinaryFn::mul:
|
||||
if (allComplex)
|
||||
return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
|
||||
if (allFloatingPoint)
|
||||
return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
|
||||
if (allBool)
|
||||
return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
|
||||
return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
|
||||
case BinaryFn::max_signed:
|
||||
assert(!allComplex);
|
||||
|
|
|
@ -137,6 +137,32 @@ func.func @dot(%arg0: memref<?xi8>, %M: index) {
|
|||
// CHECKPARALLEL: store %[[res]], %[[C]][] : memref<f32>
|
||||
|
||||
|
||||
func.func @dot_int(%arg0: memref<?xi32>, %arg1: memref<?xi32>,
|
||||
%arg3: memref<i32>) {
|
||||
// Verifies that i32 operands use the integer arith ops: the CHECK lines after this function expect arith.muli followed by arith.addi.
|
||||
linalg.dot ins(%arg0, %arg1 : memref<?xi32>, memref<?xi32>)
|
||||
outs(%arg3 : memref<i32>)
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @dot_int(
|
||||
// CHECK: %[[inc:.*]] = arith.muli {{.*}} : i32
|
||||
// CHECK-NEXT: %[[res:.*]] = arith.addi {{.*}}, %[[inc]] : i32
|
||||
// CHECK-NEXT: store %[[res]], {{.*}} : memref<i32>
|
||||
|
||||
|
||||
func.func @dot_bool(%arg0: memref<?xi1>, %arg1: memref<?xi1>,
|
||||
%arg3: memref<i1>) {
|
||||
// Verifies that i1 (boolean) operands use the saturating logical ops: the CHECK lines after this function expect arith.andi followed by arith.ori, not arith.muli/arith.addi.
|
||||
linalg.dot ins(%arg0, %arg1 : memref<?xi1>, memref<?xi1>)
|
||||
outs(%arg3 : memref<i1>)
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @dot_bool(
|
||||
// CHECK: %[[inc:.*]] = arith.andi {{.*}} : i1
|
||||
// CHECK-NEXT: %[[res:.*]] = arith.ori {{.*}}, %[[inc]] : i1
|
||||
// CHECK-NEXT: store %[[res]], {{.*}} : memref<i1>
|
||||
|
||||
|
||||
func.func @dot_view(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: memref<?xf32, offset: ?, strides: [1]>, %arg2: memref<f32>) {
|
||||
linalg.dot ins(%arg0, %arg1 : memref<?xf32, offset: ?, strides: [1]>,
|
||||
memref<?xf32, offset: ?, strides: [1]>)
|
||||
|
|
Loading…
Reference in New Issue