[flang] Add TODO for half-precision intrinsic reductions

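Emit a TODO diagnostic when lowering the MAXVAL, MINVAL, PRODUCT, DOTPRODUCT,
and SUM reductions on half-precision element types (f16/bf16 reals, and
complex kinds 2 and 3 for PRODUCT, DOTPRODUCT, and SUM). Also use
`isInteger` for the integer-type checks and shorten the fatal error messages.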

This patch is part of the upstreaming effort from the fir-dev branch.

Reviewed By: jeanPerier, PeteSteinfeld

Differential Revision: https://reviews.llvm.org/D127622

Co-authored-by: Eric Schweitz <eschweitz@nvidia.com>
This commit is contained in:
Valentin Clement 2022-06-13 17:39:15 +02:00
parent 2b89a4dc51
commit 4a8305ce85
No known key found for this signature in database
GPG Key ID: 086D54783C928776
1 changed file with 54 additions and 60 deletions

View File

@ -545,7 +545,9 @@ mlir::Value fir::runtime::genMaxval(fir::FirOpBuilder &builder,
auto eleTy = arrTy.cast<fir::SequenceType>().getEleTy();
auto dim = builder.createIntegerConstant(loc, builder.getIndexType(), 0);
if (eleTy.isF32())
if (eleTy.isF16() || eleTy.isBF16())
TODO(loc, "half-precision MAXVAL");
else if (eleTy.isF32())
func = fir::runtime::getRuntimeFunc<mkRTKey(MaxvalReal4)>(loc, builder);
else if (eleTy.isF64())
func = fir::runtime::getRuntimeFunc<mkRTKey(MaxvalReal8)>(loc, builder);
@ -553,23 +555,18 @@ mlir::Value fir::runtime::genMaxval(fir::FirOpBuilder &builder,
func = fir::runtime::getRuntimeFunc<ForcedMaxvalReal10>(loc, builder);
else if (eleTy.isF128())
func = fir::runtime::getRuntimeFunc<ForcedMaxvalReal16>(loc, builder);
else if (eleTy ==
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(1)))
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1)))
func = fir::runtime::getRuntimeFunc<mkRTKey(MaxvalInteger1)>(loc, builder);
else if (eleTy ==
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(2)))
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(2)))
func = fir::runtime::getRuntimeFunc<mkRTKey(MaxvalInteger2)>(loc, builder);
else if (eleTy ==
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(4)))
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(4)))
func = fir::runtime::getRuntimeFunc<mkRTKey(MaxvalInteger4)>(loc, builder);
else if (eleTy ==
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(8)))
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(8)))
func = fir::runtime::getRuntimeFunc<mkRTKey(MaxvalInteger8)>(loc, builder);
else if (eleTy ==
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(16)))
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(16)))
func = fir::runtime::getRuntimeFunc<ForcedMaxvalInteger16>(loc, builder);
else
fir::emitFatalError(loc, "invalid type in Maxval lowering");
fir::emitFatalError(loc, "invalid type in MAXVAL");
auto fTy = func.getFunctionType();
auto sourceFile = fir::factory::locationToFilename(builder, loc);
@ -664,7 +661,9 @@ mlir::Value fir::runtime::genMinval(fir::FirOpBuilder &builder,
auto eleTy = arrTy.cast<fir::SequenceType>().getEleTy();
auto dim = builder.createIntegerConstant(loc, builder.getIndexType(), 0);
if (eleTy.isF32())
if (eleTy.isF16() || eleTy.isBF16())
TODO(loc, "half-precision MINVAL");
else if (eleTy.isF32())
func = fir::runtime::getRuntimeFunc<mkRTKey(MinvalReal4)>(loc, builder);
else if (eleTy.isF64())
func = fir::runtime::getRuntimeFunc<mkRTKey(MinvalReal8)>(loc, builder);
@ -672,23 +671,18 @@ mlir::Value fir::runtime::genMinval(fir::FirOpBuilder &builder,
func = fir::runtime::getRuntimeFunc<ForcedMinvalReal10>(loc, builder);
else if (eleTy.isF128())
func = fir::runtime::getRuntimeFunc<ForcedMinvalReal16>(loc, builder);
else if (eleTy ==
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(1)))
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1)))
func = fir::runtime::getRuntimeFunc<mkRTKey(MinvalInteger1)>(loc, builder);
else if (eleTy ==
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(2)))
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(2)))
func = fir::runtime::getRuntimeFunc<mkRTKey(MinvalInteger2)>(loc, builder);
else if (eleTy ==
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(4)))
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(4)))
func = fir::runtime::getRuntimeFunc<mkRTKey(MinvalInteger4)>(loc, builder);
else if (eleTy ==
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(8)))
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(8)))
func = fir::runtime::getRuntimeFunc<mkRTKey(MinvalInteger8)>(loc, builder);
else if (eleTy ==
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(16)))
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(16)))
func = fir::runtime::getRuntimeFunc<ForcedMinvalInteger16>(loc, builder);
else
fir::emitFatalError(loc, "invalid type in Minval lowering");
fir::emitFatalError(loc, "invalid type in MINVAL");
auto fTy = func.getFunctionType();
auto sourceFile = fir::factory::locationToFilename(builder, loc);
@ -721,7 +715,9 @@ mlir::Value fir::runtime::genProduct(fir::FirOpBuilder &builder,
auto eleTy = arrTy.cast<fir::SequenceType>().getEleTy();
auto dim = builder.createIntegerConstant(loc, builder.getIndexType(), 0);
if (eleTy.isF32())
if (eleTy.isF16() || eleTy.isBF16())
TODO(loc, "half-precision PRODUCT");
else if (eleTy.isF32())
func = fir::runtime::getRuntimeFunc<mkRTKey(ProductReal4)>(loc, builder);
else if (eleTy.isF64())
func = fir::runtime::getRuntimeFunc<mkRTKey(ProductReal8)>(loc, builder);
@ -729,20 +725,15 @@ mlir::Value fir::runtime::genProduct(fir::FirOpBuilder &builder,
func = fir::runtime::getRuntimeFunc<ForcedProductReal10>(loc, builder);
else if (eleTy.isF128())
func = fir::runtime::getRuntimeFunc<ForcedProductReal16>(loc, builder);
else if (eleTy ==
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(1)))
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1)))
func = fir::runtime::getRuntimeFunc<mkRTKey(ProductInteger1)>(loc, builder);
else if (eleTy ==
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(2)))
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(2)))
func = fir::runtime::getRuntimeFunc<mkRTKey(ProductInteger2)>(loc, builder);
else if (eleTy ==
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(4)))
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(4)))
func = fir::runtime::getRuntimeFunc<mkRTKey(ProductInteger4)>(loc, builder);
else if (eleTy ==
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(8)))
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(8)))
func = fir::runtime::getRuntimeFunc<mkRTKey(ProductInteger8)>(loc, builder);
else if (eleTy ==
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(16)))
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(16)))
func = fir::runtime::getRuntimeFunc<ForcedProductInteger16>(loc, builder);
else if (eleTy == fir::ComplexType::get(builder.getContext(), 4))
func =
@ -754,8 +745,11 @@ mlir::Value fir::runtime::genProduct(fir::FirOpBuilder &builder,
func = fir::runtime::getRuntimeFunc<ForcedProductComplex10>(loc, builder);
else if (eleTy == fir::ComplexType::get(builder.getContext(), 16))
func = fir::runtime::getRuntimeFunc<ForcedProductComplex16>(loc, builder);
else if (eleTy == fir::ComplexType::get(builder.getContext(), 2) ||
eleTy == fir::ComplexType::get(builder.getContext(), 3))
TODO(loc, "half-precision PRODUCT");
else
fir::emitFatalError(loc, "invalid type in Product lowering");
fir::emitFatalError(loc, "invalid type in PRODUCT");
auto fTy = func.getFunctionType();
auto sourceFile = fir::factory::locationToFilename(builder, loc);
@ -788,7 +782,9 @@ mlir::Value fir::runtime::genDotProduct(fir::FirOpBuilder &builder,
auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
auto eleTy = arrTy.cast<fir::SequenceType>().getEleTy();
if (eleTy.isF32())
if (eleTy.isF16() || eleTy.isBF16())
TODO(loc, "half-precision DOTPRODUCT");
else if (eleTy.isF32())
func = fir::runtime::getRuntimeFunc<mkRTKey(DotProductReal4)>(loc, builder);
else if (eleTy.isF64())
func = fir::runtime::getRuntimeFunc<mkRTKey(DotProductReal8)>(loc, builder);
@ -808,31 +804,29 @@ mlir::Value fir::runtime::genDotProduct(fir::FirOpBuilder &builder,
else if (eleTy == fir::ComplexType::get(builder.getContext(), 16))
func =
fir::runtime::getRuntimeFunc<ForcedDotProductComplex16>(loc, builder);
else if (eleTy ==
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(1)))
else if (eleTy == fir::ComplexType::get(builder.getContext(), 2) ||
eleTy == fir::ComplexType::get(builder.getContext(), 3))
TODO(loc, "half-precision DOTPRODUCT");
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1)))
func =
fir::runtime::getRuntimeFunc<mkRTKey(DotProductInteger1)>(loc, builder);
else if (eleTy ==
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(2)))
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(2)))
func =
fir::runtime::getRuntimeFunc<mkRTKey(DotProductInteger2)>(loc, builder);
else if (eleTy ==
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(4)))
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(4)))
func =
fir::runtime::getRuntimeFunc<mkRTKey(DotProductInteger4)>(loc, builder);
else if (eleTy ==
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(8)))
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(8)))
func =
fir::runtime::getRuntimeFunc<mkRTKey(DotProductInteger8)>(loc, builder);
else if (eleTy ==
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(16)))
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(16)))
func =
fir::runtime::getRuntimeFunc<ForcedDotProductInteger16>(loc, builder);
else if (eleTy.isa<fir::LogicalType>())
func =
fir::runtime::getRuntimeFunc<mkRTKey(DotProductLogical)>(loc, builder);
else
fir::emitFatalError(loc, "invalid type in DotProduct lowering");
fir::emitFatalError(loc, "invalid type in DOTPRODUCT");
auto fTy = func.getFunctionType();
auto sourceFile = fir::factory::locationToFilename(builder, loc);
@ -873,7 +867,9 @@ mlir::Value fir::runtime::genSum(fir::FirOpBuilder &builder, mlir::Location loc,
auto eleTy = arrTy.cast<fir::SequenceType>().getEleTy();
auto dim = builder.createIntegerConstant(loc, builder.getIndexType(), 0);
if (eleTy.isF32())
if (eleTy.isF16() || eleTy.isBF16())
TODO(loc, "half-precision SUM");
else if (eleTy.isF32())
func = fir::runtime::getRuntimeFunc<mkRTKey(SumReal4)>(loc, builder);
else if (eleTy.isF64())
func = fir::runtime::getRuntimeFunc<mkRTKey(SumReal8)>(loc, builder);
@ -881,20 +877,15 @@ mlir::Value fir::runtime::genSum(fir::FirOpBuilder &builder, mlir::Location loc,
func = fir::runtime::getRuntimeFunc<ForcedSumReal10>(loc, builder);
else if (eleTy.isF128())
func = fir::runtime::getRuntimeFunc<ForcedSumReal16>(loc, builder);
else if (eleTy ==
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(1)))
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1)))
func = fir::runtime::getRuntimeFunc<mkRTKey(SumInteger1)>(loc, builder);
else if (eleTy ==
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(2)))
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(2)))
func = fir::runtime::getRuntimeFunc<mkRTKey(SumInteger2)>(loc, builder);
else if (eleTy ==
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(4)))
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(4)))
func = fir::runtime::getRuntimeFunc<mkRTKey(SumInteger4)>(loc, builder);
else if (eleTy ==
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(8)))
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(8)))
func = fir::runtime::getRuntimeFunc<mkRTKey(SumInteger8)>(loc, builder);
else if (eleTy ==
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(16)))
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(16)))
func = fir::runtime::getRuntimeFunc<ForcedSumInteger16>(loc, builder);
else if (eleTy == fir::ComplexType::get(builder.getContext(), 4))
func = fir::runtime::getRuntimeFunc<mkRTKey(CppSumComplex4)>(loc, builder);
@ -904,8 +895,11 @@ mlir::Value fir::runtime::genSum(fir::FirOpBuilder &builder, mlir::Location loc,
func = fir::runtime::getRuntimeFunc<ForcedSumComplex10>(loc, builder);
else if (eleTy == fir::ComplexType::get(builder.getContext(), 16))
func = fir::runtime::getRuntimeFunc<ForcedSumComplex16>(loc, builder);
else if (eleTy == fir::ComplexType::get(builder.getContext(), 2) ||
eleTy == fir::ComplexType::get(builder.getContext(), 3))
TODO(loc, "half-precision SUM");
else
fir::emitFatalError(loc, "invalid type in Sum lowering");
fir::emitFatalError(loc, "invalid type in SUM");
auto fTy = func.getFunctionType();
auto sourceFile = fir::factory::locationToFilename(builder, loc);