[flang] Add an explicit condition for the BITS=0 case in the MASKL and MASKR intrinsics.

This fixes issue #56706.

Differential Revision: https://reviews.llvm.org/D130590
This commit is contained in:
Tarun Prabhu 2022-08-08 08:56:29 -06:00
parent e2bfbed2bb
commit c1f65df19c
3 changed files with 46 additions and 11 deletions

View File

@ -3297,6 +3297,7 @@ mlir::Value IntrinsicLibrary::genMask(mlir::Type resultType,
llvm::ArrayRef<mlir::Value> args) {
assert(args.size() == 2);
mlir::Value zero = builder.createIntegerConstant(loc, resultType, 0);
mlir::Value ones = builder.createIntegerConstant(loc, resultType, -1);
mlir::Value bitSize = builder.createIntegerConstant(
loc, resultType, resultType.getIntOrFloatBitWidth());
@ -3309,7 +3310,11 @@ mlir::Value IntrinsicLibrary::genMask(mlir::Type resultType,
// in this case either, so we choose the most efficient implementation.
mlir::Value shift =
builder.create<mlir::arith::SubIOp>(loc, bitSize, bitsToSet);
return builder.create<Shift>(loc, ones, shift);
mlir::Value shifted = builder.create<Shift>(loc, ones, shift);
mlir::Value isZero = builder.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::eq, bitsToSet, zero);
return builder.create<mlir::arith::SelectOp>(loc, isZero, zero, shifted);
}
// MATMUL

View File

@ -9,11 +9,14 @@ subroutine maskl_test(a, b)
! CHECK: %[[A_VAL:.*]] = fir.load %[[A]] : !fir.ref<i32>
b = maskl(a)
! CHECK: %[[C__0:.*]] = arith.constant 0 : i32
! CHECK: %[[C__1:.*]] = arith.constant -1 : i32
! CHECK: %[[BITS:.*]] = arith.constant 32 : i32
! CHECK: %[[LEN:.*]] = arith.subi %[[BITS]], %[[A_VAL]] : i32
! CHECK: %[[SHIFT:.*]] = arith.shli %[[C__1]], %[[LEN]] : i32
! CHECK: fir.store %[[SHIFT]] to %[[B]] : !fir.ref<i32>
! CHECK: %[[IS0:.*]] = arith.cmpi eq, %[[A_VAL]], %[[C__0]] : i32
! CHECK: %[[RESULT:.*]] = arith.select %[[IS0]], %[[C__0]], %[[SHIFT]] : i32
! CHECK: fir.store %[[RESULT]] to %[[B]] : !fir.ref<i32>
end subroutine maskl_test
! CHECK-LABEL: maskl1_test
@ -24,12 +27,15 @@ subroutine maskl1_test(a, b)
! CHECK: %[[A_VAL:.*]] = fir.load %[[A]] : !fir.ref<i32>
b = maskl(a, 1)
! CHECK: %[[C__0:.*]] = arith.constant 0 : i8
! CHECK: %[[C__1:.*]] = arith.constant -1 : i8
! CHECK: %[[BITS:.*]] = arith.constant 8 : i8
! CHECK: %[[A_CONV:.*]] = fir.convert %[[A_VAL]] : (i32) -> i8
! CHECK: %[[LEN:.*]] = arith.subi %[[BITS]], %[[A_CONV]] : i8
! CHECK: %[[SHIFT:.*]] = arith.shli %[[C__1]], %[[LEN]] : i8
! CHECK: fir.store %[[SHIFT]] to %[[B]] : !fir.ref<i8>
! CHECK: %[[IS0:.*]] = arith.cmpi eq, %[[A_CONV]], %[[C__0]] : i8
! CHECK: %[[RESULT:.*]] = arith.select %[[IS0]], %[[C__0]], %[[SHIFT]] : i8
! CHECK: fir.store %[[RESULT]] to %[[B]] : !fir.ref<i8>
end subroutine maskl1_test
! CHECK-LABEL: maskl2_test
@ -40,12 +46,15 @@ subroutine maskl2_test(a, b)
! CHECK: %[[A_VAL:.*]] = fir.load %[[A]] : !fir.ref<i32>
b = maskl(a, 2)
! CHECK: %[[C__0:.*]] = arith.constant 0 : i16
! CHECK: %[[C__1:.*]] = arith.constant -1 : i16
! CHECK: %[[BITS:.*]] = arith.constant 16 : i16
! CHECK: %[[A_CONV:.*]] = fir.convert %[[A_VAL]] : (i32) -> i16
! CHECK: %[[LEN:.*]] = arith.subi %[[BITS]], %[[A_CONV]] : i16
! CHECK: %[[SHIFT:.*]] = arith.shli %[[C__1]], %[[LEN]] : i16
! CHECK: fir.store %[[SHIFT]] to %[[B]] : !fir.ref<i16>
! CHECK: %[[IS0:.*]] = arith.cmpi eq, %[[A_CONV]], %[[C__0]] : i16
! CHECK: %[[RESULT:.*]] = arith.select %[[IS0]], %[[C__0]], %[[SHIFT]] : i16
! CHECK: fir.store %[[RESULT]] to %[[B]] : !fir.ref<i16>
end subroutine maskl2_test
! CHECK-LABEL: maskl4_test
@ -56,11 +65,14 @@ subroutine maskl4_test(a, b)
! CHECK: %[[A_VAL:.*]] = fir.load %[[A]] : !fir.ref<i32>
b = maskl(a, 4)
! CHECK: %[[C__0:.*]] = arith.constant 0 : i32
! CHECK: %[[C__1:.*]] = arith.constant -1 : i32
! CHECK: %[[BITS:.*]] = arith.constant 32 : i32
! CHECK: %[[LEN:.*]] = arith.subi %[[BITS]], %[[A_VAL]] : i32
! CHECK: %[[SHIFT:.*]] = arith.shli %[[C__1]], %[[LEN]] : i32
! CHECK: fir.store %[[SHIFT]] to %[[B]] : !fir.ref<i32>
! CHECK: %[[IS0:.*]] = arith.cmpi eq, %[[A_VAL]], %[[C__0]] : i32
! CHECK: %[[RESULT:.*]] = arith.select %[[IS0]], %[[C__0]], %[[SHIFT]] : i32
! CHECK: fir.store %[[RESULT]] to %[[B]] : !fir.ref<i32>
end subroutine maskl4_test
! CHECK-LABEL: maskl8_test
@ -71,12 +83,15 @@ subroutine maskl8_test(a, b)
! CHECK: %[[A_VAL:.*]] = fir.load %[[A]] : !fir.ref<i32>
b = maskl(a, 8)
! CHECK: %[[C__0:.*]] = arith.constant 0 : i64
! CHECK: %[[C__1:.*]] = arith.constant -1 : i64
! CHECK: %[[BITS:.*]] = arith.constant 64 : i64
! CHECK: %[[A_CONV:.*]] = fir.convert %[[A_VAL]] : (i32) -> i64
! CHECK: %[[LEN:.*]] = arith.subi %[[BITS]], %[[A_CONV]] : i64
! CHECK: %[[SHIFT:.*]] = arith.shli %[[C__1]], %[[LEN]] : i64
! CHECK: fir.store %[[SHIFT]] to %[[B]] : !fir.ref<i64>
! CHECK: %[[IS0:.*]] = arith.cmpi eq, %[[A_CONV]], %[[C__0]] : i64
! CHECK: %[[RESULT:.*]] = arith.select %[[IS0]], %[[C__0]], %[[SHIFT]] : i64
! CHECK: fir.store %[[RESULT]] to %[[B]] : !fir.ref<i64>
end subroutine maskl8_test
! TODO: Code containing 128-bit integer literals current breaks. This is

View File

@ -9,11 +9,14 @@ subroutine maskr_test(a, b)
! CHECK: %[[A_VAL:.*]] = fir.load %[[A]] : !fir.ref<i32>
b = maskr(a)
! CHECK: %[[C__0:.*]] = arith.constant 0 : i32
! CHECK: %[[C__1:.*]] = arith.constant -1 : i32
! CHECK: %[[BITS:.*]] = arith.constant 32 : i32
! CHECK: %[[LEN:.*]] = arith.subi %[[BITS]], %[[A_VAL]] : i32
! CHECK: %[[SHIFT:.*]] = arith.shrui %[[C__1]], %[[LEN]] : i32
! CHECK: fir.store %[[SHIFT]] to %[[B]] : !fir.ref<i32>
! CHECK: %[[IS0:.*]] = arith.cmpi eq, %[[A_VAL]], %[[C__0]] : i32
! CHECK: %[[RESULT:.*]] = arith.select %[[IS0]], %[[C__0]], %[[SHIFT]] : i32
! CHECK: fir.store %[[RESULT]] to %[[B]] : !fir.ref<i32>
end subroutine maskr_test
! CHECK-LABEL: maskr1_test
@ -24,12 +27,15 @@ subroutine maskr1_test(a, b)
! CHECK: %[[A_VAL:.*]] = fir.load %[[A]] : !fir.ref<i32>
b = maskr(a, 1)
! CHECK: %[[C__0:.*]] = arith.constant 0 : i8
! CHECK: %[[C__1:.*]] = arith.constant -1 : i8
! CHECK: %[[BITS:.*]] = arith.constant 8 : i8
! CHECK: %[[A_CONV:.*]] = fir.convert %[[A_VAL]] : (i32) -> i8
! CHECK: %[[LEN:.*]] = arith.subi %[[BITS]], %[[A_CONV]] : i8
! CHECK: %[[SHIFT:.*]] = arith.shrui %[[C__1]], %[[LEN]] : i8
! CHECK: fir.store %[[SHIFT]] to %[[B]] : !fir.ref<i8>
! CHECK: %[[IS0:.*]] = arith.cmpi eq, %[[A_CONV]], %[[C__0]] : i8
! CHECK: %[[RESULT:.*]] = arith.select %[[IS0]], %[[C__0]], %[[SHIFT]] : i8
! CHECK: fir.store %[[RESULT]] to %[[B]] : !fir.ref<i8>
end subroutine maskr1_test
! CHECK-LABEL: maskr2_test
@ -40,12 +46,15 @@ subroutine maskr2_test(a, b)
! CHECK: %[[A_VAL:.*]] = fir.load %[[A]] : !fir.ref<i32>
b = maskr(a, 2)
! CHECK: %[[C__0:.*]] = arith.constant 0 : i16
! CHECK: %[[C__1:.*]] = arith.constant -1 : i16
! CHECK: %[[BITS:.*]] = arith.constant 16 : i16
! CHECK: %[[A_CONV:.*]] = fir.convert %[[A_VAL]] : (i32) -> i16
! CHECK: %[[LEN:.*]] = arith.subi %[[BITS]], %[[A_CONV]] : i16
! CHECK: %[[SHIFT:.*]] = arith.shrui %[[C__1]], %[[LEN]] : i16
! CHECK: fir.store %[[SHIFT]] to %[[B]] : !fir.ref<i16>
! CHECK: %[[IS0:.*]] = arith.cmpi eq, %[[A_CONV]], %[[C__0]] : i16
! CHECK: %[[RESULT:.*]] = arith.select %[[IS0]], %[[C__0]], %[[SHIFT]] : i16
! CHECK: fir.store %[[RESULT]] to %[[B]] : !fir.ref<i16>
end subroutine maskr2_test
! CHECK-LABEL: maskr4_test
@ -56,11 +65,14 @@ subroutine maskr4_test(a, b)
! CHECK: %[[A_VAL:.*]] = fir.load %[[A]] : !fir.ref<i32>
b = maskr(a, 4)
! CHECK: %[[C__0:.*]] = arith.constant 0 : i32
! CHECK: %[[C__1:.*]] = arith.constant -1 : i32
! CHECK: %[[BITS:.*]] = arith.constant 32 : i32
! CHECK: %[[LEN:.*]] = arith.subi %[[BITS]], %[[A_VAL]] : i32
! CHECK: %[[SHIFT:.*]] = arith.shrui %[[C__1]], %[[LEN]] : i32
! CHECK: fir.store %[[SHIFT]] to %[[B]] : !fir.ref<i32>
! CHECK: %[[IS0:.*]] = arith.cmpi eq, %[[A_VAL]], %[[C__0]] : i32
! CHECK: %[[RESULT:.*]] = arith.select %[[IS0]], %[[C__0]], %[[SHIFT]] : i32
! CHECK: fir.store %[[RESULT]] to %[[B]] : !fir.ref<i32>
end subroutine maskr4_test
! CHECK-LABEL: maskr8_test
@ -71,12 +83,15 @@ subroutine maskr8_test(a, b)
! CHECK: %[[A_VAL:.*]] = fir.load %[[A]] : !fir.ref<i32>
b = maskr(a, 8)
! CHECK: %[[C__0:.*]] = arith.constant 0 : i64
! CHECK: %[[C__1:.*]] = arith.constant -1 : i64
! CHECK: %[[BITS:.*]] = arith.constant 64 : i64
! CHECK: %[[A_CONV:.*]] = fir.convert %[[A_VAL]] : (i32) -> i64
! CHECK: %[[LEN:.*]] = arith.subi %[[BITS]], %[[A_CONV]] : i64
! CHECK: %[[SHIFT:.*]] = arith.shrui %[[C__1]], %[[LEN]] : i64
! CHECK: fir.store %[[SHIFT]] to %[[B]] : !fir.ref<i64>
! CHECK: %[[IS0:.*]] = arith.cmpi eq, %[[A_CONV]], %[[C__0]] : i64
! CHECK: %[[RESULT:.*]] = arith.select %[[IS0]], %[[C__0]], %[[SHIFT]] : i64
! CHECK: fir.store %[[RESULT]] to %[[B]] : !fir.ref<i64>
end subroutine maskr8_test
! TODO: Code containing 128-bit integer literals current breaks. This is