[flang] Lower mvbits intrinsic

This patch adds the lowering for the `mvbits`
intrinsic.

This patch is part of the upstreaming effort from fir-dev branch.

Reviewed By: PeteSteinfeld

Differential Revision: https://reviews.llvm.org/D122412

Co-authored-by: V Donaldson <vdonaldson@nvidia.com>
Co-authored-by: Jean Perier <jperier@nvidia.com>
This commit is contained in:
Valentin Clement 2022-03-25 08:00:10 +01:00
parent 56a54910c5
commit 50354558a7
No known key found for this signature in database
GPG Key ID: 086D54783C928776
2 changed files with 131 additions and 0 deletions

View File

@ -499,6 +499,7 @@ struct IntrinsicLibrary {
fir::ExtendedValue genMinval(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
mlir::Value genMod(mlir::Type, llvm::ArrayRef<mlir::Value>);
mlir::Value genModulo(mlir::Type, llvm::ArrayRef<mlir::Value>);
void genMvbits(llvm::ArrayRef<fir::ExtendedValue>);
mlir::Value genNearest(mlir::Type, llvm::ArrayRef<mlir::Value>);
mlir::Value genNint(mlir::Type, llvm::ArrayRef<mlir::Value>);
mlir::Value genNot(mlir::Type, llvm::ArrayRef<mlir::Value>);
@ -805,6 +806,13 @@ static constexpr IntrinsicHandler handlers[]{
/*isElemental=*/false},
{"mod", &I::genMod},
{"modulo", &I::genModulo},
{"mvbits",
&I::genMvbits,
{{{"from", asValue},
{"frompos", asValue},
{"len", asValue},
{"to", asAddr},
{"topos", asValue}}}},
{"nearest", &I::genNearest},
{"nint", &I::genNint},
{"not", &I::genNot},
@ -2854,6 +2862,53 @@ mlir::Value IntrinsicLibrary::genModulo(mlir::Type resultType,
remainder);
}
// MVBITS
void IntrinsicLibrary::genMvbits(llvm::ArrayRef<fir::ExtendedValue> args) {
// A conformant MVBITS(FROM,FROMPOS,LEN,TO,TOPOS) call satisfies:
// FROMPOS >= 0
// LEN >= 0
// TOPOS >= 0
// FROMPOS + LEN <= BIT_SIZE(FROM)
// TOPOS + LEN <= BIT_SIZE(TO)
// MASK = -1 >> (BIT_SIZE(FROM) - LEN)
// TO = LEN == 0 ? TO : ((!(MASK << TOPOS)) & TO) |
// (((FROM >> FROMPOS) & MASK) << TOPOS)
assert(args.size() == 5);
auto unbox = [&](fir::ExtendedValue exv) {
const mlir::Value *arg = exv.getUnboxed();
assert(arg && "nonscalar mvbits argument");
return *arg;
};
mlir::Value from = unbox(args[0]);
mlir::Type resultType = from.getType();
mlir::Value frompos = builder.createConvert(loc, resultType, unbox(args[1]));
mlir::Value len = builder.createConvert(loc, resultType, unbox(args[2]));
mlir::Value toAddr = unbox(args[3]);
assert(fir::dyn_cast_ptrEleTy(toAddr.getType()) == resultType &&
"mismatched mvbits types");
auto to = builder.create<fir::LoadOp>(loc, resultType, toAddr);
mlir::Value topos = builder.createConvert(loc, resultType, unbox(args[4]));
mlir::Value zero = builder.createIntegerConstant(loc, resultType, 0);
mlir::Value ones = builder.createIntegerConstant(loc, resultType, -1);
mlir::Value bitSize = builder.createIntegerConstant(
loc, resultType, resultType.cast<mlir::IntegerType>().getWidth());
auto shiftCount = builder.create<mlir::arith::SubIOp>(loc, bitSize, len);
auto mask = builder.create<mlir::arith::ShRUIOp>(loc, ones, shiftCount);
auto unchangedTmp1 = builder.create<mlir::arith::ShLIOp>(loc, mask, topos);
auto unchangedTmp2 =
builder.create<mlir::arith::XOrIOp>(loc, unchangedTmp1, ones);
auto unchanged = builder.create<mlir::arith::AndIOp>(loc, unchangedTmp2, to);
auto frombitsTmp1 = builder.create<mlir::arith::ShRUIOp>(loc, from, frompos);
auto frombitsTmp2 =
builder.create<mlir::arith::AndIOp>(loc, frombitsTmp1, mask);
auto frombits = builder.create<mlir::arith::ShLIOp>(loc, frombitsTmp2, topos);
auto resTmp = builder.create<mlir::arith::OrIOp>(loc, unchanged, frombits);
auto lenIsZero = builder.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::eq, len, zero);
auto res = builder.create<mlir::arith::SelectOp>(loc, lenIsZero, to, resTmp);
builder.create<fir::StoreOp>(loc, res, toAddr);
}
// NEAREST
mlir::Value IntrinsicLibrary::genNearest(mlir::Type resultType,
llvm::ArrayRef<mlir::Value> args) {

View File

@ -0,0 +1,76 @@
! RUN: bbc -emit-fir %s -o - | FileCheck %s
! CHECK-LABEL: func @_QPmvbits_test(
function mvbits_test(from, frompos, len, to, topos)
! CHECK: %[[result:.*]] = fir.alloca i32 {bindc_name = "mvbits_test"
! CHECK-DAG: %[[from:.*]] = fir.load %arg0 : !fir.ref<i32>
! CHECK-DAG: %[[frompos:.*]] = fir.load %arg1 : !fir.ref<i32>
! CHECK-DAG: %[[len:.*]] = fir.load %arg2 : !fir.ref<i32>
! CHECK-DAG: %[[to:.*]] = fir.load %arg3 : !fir.ref<i32>
! CHECK-DAG: %[[topos:.*]] = fir.load %arg4 : !fir.ref<i32>
integer :: from, frompos, len, to, topos
integer :: mvbits_test
! CHECK: %[[VAL_11:.*]] = arith.constant 0 : i32
! CHECK: %[[VAL_12:.*]] = arith.constant -1 : i32
! CHECK: %[[VAL_13:.*]] = arith.constant 32 : i32
! CHECK: %[[VAL_14:.*]] = arith.subi %[[VAL_13]], %[[len]] : i32
! CHECK: %[[VAL_15:.*]] = arith.shrui %[[VAL_12]], %[[VAL_14]] : i32
! CHECK: %[[VAL_16:.*]] = arith.shli %[[VAL_15]], %[[topos]] : i32
! CHECK: %[[VAL_17:.*]] = arith.xori %[[VAL_16]], %[[VAL_12]] : i32
! CHECK: %[[VAL_18:.*]] = arith.andi %[[VAL_17]], %[[to]] : i32
! CHECK: %[[VAL_19:.*]] = arith.shrui %[[from]], %[[frompos]] : i32
! CHECK: %[[VAL_20:.*]] = arith.andi %[[VAL_19]], %[[VAL_15]] : i32
! CHECK: %[[VAL_21:.*]] = arith.shli %[[VAL_20]], %[[topos]] : i32
! CHECK: %[[VAL_22:.*]] = arith.ori %[[VAL_18]], %[[VAL_21]] : i32
! CHECK: %[[VAL_23:.*]] = arith.cmpi eq, %[[len]], %[[VAL_11]] : i32
! CHECK: %[[VAL_24:.*]] = arith.select %[[VAL_23]], %[[to]], %[[VAL_22]] : i32
! CHECK: fir.store %[[VAL_24]] to %arg3 : !fir.ref<i32>
! CHECK: %[[VAL_25:.*]] = fir.load %arg3 : !fir.ref<i32>
! CHECK: fir.store %[[VAL_25]] to %[[result]] : !fir.ref<i32>
call mvbits(from, frompos, len, to, topos)
! CHECK: %[[VAL_26:.*]] = fir.load %[[result]] : !fir.ref<i32>
! CHECK: return %[[VAL_26]] : i32
mvbits_test = to
end
! CHECK-LABEL: func @_QPmvbits_array_test(
! CHECK-SAME: %[[VAL_0:.*]]: !fir.box<!fir.array<?xi32>>{{.*}}, %[[VAL_1:.*]]: !fir.ref<i32>{{.*}}, %[[VAL_2:.*]]: !fir.ref<i32>{{.*}}, %[[VAL_3:.*]]: !fir.box<!fir.array<?xi32>>{{.*}}, %[[VAL_4:.*]]: !fir.ref<i32>{{.*}}) {
! CHECK: %[[VAL_5:.*]] = arith.constant 0 : index
! CHECK: %[[VAL_6:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_5]] : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
! CHECK: %[[VAL_7:.*]] = fir.array_load %[[VAL_0]] : (!fir.box<!fir.array<?xi32>>) -> !fir.array<?xi32>
! CHECK: %[[VAL_8:.*]] = fir.load %[[VAL_1]] : !fir.ref<i32>
! CHECK: %[[VAL_9:.*]] = fir.load %[[VAL_2]] : !fir.ref<i32>
! CHECK: %[[VAL_10:.*]] = fir.load %[[VAL_4]] : !fir.ref<i32>
! CHECK: %[[VAL_11:.*]] = arith.constant 1 : index
! CHECK: %[[VAL_12:.*]] = arith.constant 0 : index
! CHECK: %[[VAL_13:.*]] = arith.subi %[[VAL_6]]#1, %[[VAL_11]] : index
! CHECK: fir.do_loop %[[VAL_14:.*]] = %[[VAL_12]] to %[[VAL_13]] step %[[VAL_11]] {
! CHECK: %[[VAL_15:.*]] = fir.array_fetch %[[VAL_7]], %[[VAL_14]] : (!fir.array<?xi32>, index) -> i32
! CHECK: %[[VAL_16:.*]] = arith.constant 1 : index
! CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_14]], %[[VAL_16]] : index
! CHECK: %[[VAL_18:.*]] = fir.array_coor %[[VAL_3]] %[[VAL_17]] : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
! CHECK: %[[VAL_19:.*]] = fir.load %[[VAL_18]] : !fir.ref<i32>
! CHECK: %[[VAL_20:.*]] = arith.constant 0 : i32
! CHECK: %[[VAL_21:.*]] = arith.constant -1 : i32
! CHECK: %[[VAL_22:.*]] = arith.constant 32 : i32
! CHECK: %[[VAL_23:.*]] = arith.subi %[[VAL_22]], %[[VAL_9]] : i32
! CHECK: %[[VAL_24:.*]] = arith.shrui %[[VAL_21]], %[[VAL_23]] : i32
! CHECK: %[[VAL_25:.*]] = arith.shli %[[VAL_24]], %[[VAL_10]] : i32
! CHECK: %[[VAL_26:.*]] = arith.xori %[[VAL_25]], %[[VAL_21]] : i32
! CHECK: %[[VAL_27:.*]] = arith.andi %[[VAL_26]], %[[VAL_19]] : i32
! CHECK: %[[VAL_28:.*]] = arith.shrui %[[VAL_15]], %[[VAL_8]] : i32
! CHECK: %[[VAL_29:.*]] = arith.andi %[[VAL_28]], %[[VAL_24]] : i32
! CHECK: %[[VAL_30:.*]] = arith.shli %[[VAL_29]], %[[VAL_10]] : i32
! CHECK: %[[VAL_31:.*]] = arith.ori %[[VAL_27]], %[[VAL_30]] : i32
! CHECK: %[[VAL_32:.*]] = arith.cmpi eq, %[[VAL_9]], %[[VAL_20]] : i32
! CHECK: %[[VAL_33:.*]] = arith.select %[[VAL_32]], %[[VAL_19]], %[[VAL_31]] : i32
! CHECK: fir.store %[[VAL_33]] to %[[VAL_18]] : !fir.ref<i32>
! CHECK: }
! CHECK: return
! CHECK: }
subroutine mvbits_array_test(from, frompos, len, to, topos)
integer :: from(:), frompos, len, to(:), topos
call mvbits(from, frompos, len, to, topos)
end subroutine