forked from OSchip/llvm-project
[flang] Add TODO for half-precision intrinsic reductions
Add TODO for half-precision for reduction. This patch is part of the upstreaming effort from fir-dev branch. Reviewed By: jeanPerier, PeteSteinfeld Differential Revision: https://reviews.llvm.org/D127622 Co-authored-by: Eric Schweitz <eschweitz@nvidia.com>
This commit is contained in:
parent
2b89a4dc51
commit
4a8305ce85
|
@ -545,7 +545,9 @@ mlir::Value fir::runtime::genMaxval(fir::FirOpBuilder &builder,
|
|||
auto eleTy = arrTy.cast<fir::SequenceType>().getEleTy();
|
||||
auto dim = builder.createIntegerConstant(loc, builder.getIndexType(), 0);
|
||||
|
||||
if (eleTy.isF32())
|
||||
if (eleTy.isF16() || eleTy.isBF16())
|
||||
TODO(loc, "half-precision MAXVAL");
|
||||
else if (eleTy.isF32())
|
||||
func = fir::runtime::getRuntimeFunc<mkRTKey(MaxvalReal4)>(loc, builder);
|
||||
else if (eleTy.isF64())
|
||||
func = fir::runtime::getRuntimeFunc<mkRTKey(MaxvalReal8)>(loc, builder);
|
||||
|
@ -553,23 +555,18 @@ mlir::Value fir::runtime::genMaxval(fir::FirOpBuilder &builder,
|
|||
func = fir::runtime::getRuntimeFunc<ForcedMaxvalReal10>(loc, builder);
|
||||
else if (eleTy.isF128())
|
||||
func = fir::runtime::getRuntimeFunc<ForcedMaxvalReal16>(loc, builder);
|
||||
else if (eleTy ==
|
||||
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(1)))
|
||||
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1)))
|
||||
func = fir::runtime::getRuntimeFunc<mkRTKey(MaxvalInteger1)>(loc, builder);
|
||||
else if (eleTy ==
|
||||
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(2)))
|
||||
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(2)))
|
||||
func = fir::runtime::getRuntimeFunc<mkRTKey(MaxvalInteger2)>(loc, builder);
|
||||
else if (eleTy ==
|
||||
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(4)))
|
||||
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(4)))
|
||||
func = fir::runtime::getRuntimeFunc<mkRTKey(MaxvalInteger4)>(loc, builder);
|
||||
else if (eleTy ==
|
||||
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(8)))
|
||||
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(8)))
|
||||
func = fir::runtime::getRuntimeFunc<mkRTKey(MaxvalInteger8)>(loc, builder);
|
||||
else if (eleTy ==
|
||||
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(16)))
|
||||
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(16)))
|
||||
func = fir::runtime::getRuntimeFunc<ForcedMaxvalInteger16>(loc, builder);
|
||||
else
|
||||
fir::emitFatalError(loc, "invalid type in Maxval lowering");
|
||||
fir::emitFatalError(loc, "invalid type in MAXVAL");
|
||||
|
||||
auto fTy = func.getFunctionType();
|
||||
auto sourceFile = fir::factory::locationToFilename(builder, loc);
|
||||
|
@ -664,7 +661,9 @@ mlir::Value fir::runtime::genMinval(fir::FirOpBuilder &builder,
|
|||
auto eleTy = arrTy.cast<fir::SequenceType>().getEleTy();
|
||||
auto dim = builder.createIntegerConstant(loc, builder.getIndexType(), 0);
|
||||
|
||||
if (eleTy.isF32())
|
||||
if (eleTy.isF16() || eleTy.isBF16())
|
||||
TODO(loc, "half-precision MINVAL");
|
||||
else if (eleTy.isF32())
|
||||
func = fir::runtime::getRuntimeFunc<mkRTKey(MinvalReal4)>(loc, builder);
|
||||
else if (eleTy.isF64())
|
||||
func = fir::runtime::getRuntimeFunc<mkRTKey(MinvalReal8)>(loc, builder);
|
||||
|
@ -672,23 +671,18 @@ mlir::Value fir::runtime::genMinval(fir::FirOpBuilder &builder,
|
|||
func = fir::runtime::getRuntimeFunc<ForcedMinvalReal10>(loc, builder);
|
||||
else if (eleTy.isF128())
|
||||
func = fir::runtime::getRuntimeFunc<ForcedMinvalReal16>(loc, builder);
|
||||
else if (eleTy ==
|
||||
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(1)))
|
||||
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1)))
|
||||
func = fir::runtime::getRuntimeFunc<mkRTKey(MinvalInteger1)>(loc, builder);
|
||||
else if (eleTy ==
|
||||
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(2)))
|
||||
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(2)))
|
||||
func = fir::runtime::getRuntimeFunc<mkRTKey(MinvalInteger2)>(loc, builder);
|
||||
else if (eleTy ==
|
||||
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(4)))
|
||||
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(4)))
|
||||
func = fir::runtime::getRuntimeFunc<mkRTKey(MinvalInteger4)>(loc, builder);
|
||||
else if (eleTy ==
|
||||
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(8)))
|
||||
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(8)))
|
||||
func = fir::runtime::getRuntimeFunc<mkRTKey(MinvalInteger8)>(loc, builder);
|
||||
else if (eleTy ==
|
||||
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(16)))
|
||||
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(16)))
|
||||
func = fir::runtime::getRuntimeFunc<ForcedMinvalInteger16>(loc, builder);
|
||||
else
|
||||
fir::emitFatalError(loc, "invalid type in Minval lowering");
|
||||
fir::emitFatalError(loc, "invalid type in MINVAL");
|
||||
|
||||
auto fTy = func.getFunctionType();
|
||||
auto sourceFile = fir::factory::locationToFilename(builder, loc);
|
||||
|
@ -721,7 +715,9 @@ mlir::Value fir::runtime::genProduct(fir::FirOpBuilder &builder,
|
|||
auto eleTy = arrTy.cast<fir::SequenceType>().getEleTy();
|
||||
auto dim = builder.createIntegerConstant(loc, builder.getIndexType(), 0);
|
||||
|
||||
if (eleTy.isF32())
|
||||
if (eleTy.isF16() || eleTy.isBF16())
|
||||
TODO(loc, "half-precision PRODUCT");
|
||||
else if (eleTy.isF32())
|
||||
func = fir::runtime::getRuntimeFunc<mkRTKey(ProductReal4)>(loc, builder);
|
||||
else if (eleTy.isF64())
|
||||
func = fir::runtime::getRuntimeFunc<mkRTKey(ProductReal8)>(loc, builder);
|
||||
|
@ -729,20 +725,15 @@ mlir::Value fir::runtime::genProduct(fir::FirOpBuilder &builder,
|
|||
func = fir::runtime::getRuntimeFunc<ForcedProductReal10>(loc, builder);
|
||||
else if (eleTy.isF128())
|
||||
func = fir::runtime::getRuntimeFunc<ForcedProductReal16>(loc, builder);
|
||||
else if (eleTy ==
|
||||
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(1)))
|
||||
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1)))
|
||||
func = fir::runtime::getRuntimeFunc<mkRTKey(ProductInteger1)>(loc, builder);
|
||||
else if (eleTy ==
|
||||
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(2)))
|
||||
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(2)))
|
||||
func = fir::runtime::getRuntimeFunc<mkRTKey(ProductInteger2)>(loc, builder);
|
||||
else if (eleTy ==
|
||||
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(4)))
|
||||
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(4)))
|
||||
func = fir::runtime::getRuntimeFunc<mkRTKey(ProductInteger4)>(loc, builder);
|
||||
else if (eleTy ==
|
||||
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(8)))
|
||||
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(8)))
|
||||
func = fir::runtime::getRuntimeFunc<mkRTKey(ProductInteger8)>(loc, builder);
|
||||
else if (eleTy ==
|
||||
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(16)))
|
||||
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(16)))
|
||||
func = fir::runtime::getRuntimeFunc<ForcedProductInteger16>(loc, builder);
|
||||
else if (eleTy == fir::ComplexType::get(builder.getContext(), 4))
|
||||
func =
|
||||
|
@ -754,8 +745,11 @@ mlir::Value fir::runtime::genProduct(fir::FirOpBuilder &builder,
|
|||
func = fir::runtime::getRuntimeFunc<ForcedProductComplex10>(loc, builder);
|
||||
else if (eleTy == fir::ComplexType::get(builder.getContext(), 16))
|
||||
func = fir::runtime::getRuntimeFunc<ForcedProductComplex16>(loc, builder);
|
||||
else if (eleTy == fir::ComplexType::get(builder.getContext(), 2) ||
|
||||
eleTy == fir::ComplexType::get(builder.getContext(), 3))
|
||||
TODO(loc, "half-precision PRODUCT");
|
||||
else
|
||||
fir::emitFatalError(loc, "invalid type in Product lowering");
|
||||
fir::emitFatalError(loc, "invalid type in PRODUCT");
|
||||
|
||||
auto fTy = func.getFunctionType();
|
||||
auto sourceFile = fir::factory::locationToFilename(builder, loc);
|
||||
|
@ -788,7 +782,9 @@ mlir::Value fir::runtime::genDotProduct(fir::FirOpBuilder &builder,
|
|||
auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
|
||||
auto eleTy = arrTy.cast<fir::SequenceType>().getEleTy();
|
||||
|
||||
if (eleTy.isF32())
|
||||
if (eleTy.isF16() || eleTy.isBF16())
|
||||
TODO(loc, "half-precision DOTPRODUCT");
|
||||
else if (eleTy.isF32())
|
||||
func = fir::runtime::getRuntimeFunc<mkRTKey(DotProductReal4)>(loc, builder);
|
||||
else if (eleTy.isF64())
|
||||
func = fir::runtime::getRuntimeFunc<mkRTKey(DotProductReal8)>(loc, builder);
|
||||
|
@ -808,31 +804,29 @@ mlir::Value fir::runtime::genDotProduct(fir::FirOpBuilder &builder,
|
|||
else if (eleTy == fir::ComplexType::get(builder.getContext(), 16))
|
||||
func =
|
||||
fir::runtime::getRuntimeFunc<ForcedDotProductComplex16>(loc, builder);
|
||||
else if (eleTy ==
|
||||
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(1)))
|
||||
else if (eleTy == fir::ComplexType::get(builder.getContext(), 2) ||
|
||||
eleTy == fir::ComplexType::get(builder.getContext(), 3))
|
||||
TODO(loc, "half-precision DOTPRODUCT");
|
||||
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1)))
|
||||
func =
|
||||
fir::runtime::getRuntimeFunc<mkRTKey(DotProductInteger1)>(loc, builder);
|
||||
else if (eleTy ==
|
||||
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(2)))
|
||||
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(2)))
|
||||
func =
|
||||
fir::runtime::getRuntimeFunc<mkRTKey(DotProductInteger2)>(loc, builder);
|
||||
else if (eleTy ==
|
||||
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(4)))
|
||||
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(4)))
|
||||
func =
|
||||
fir::runtime::getRuntimeFunc<mkRTKey(DotProductInteger4)>(loc, builder);
|
||||
else if (eleTy ==
|
||||
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(8)))
|
||||
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(8)))
|
||||
func =
|
||||
fir::runtime::getRuntimeFunc<mkRTKey(DotProductInteger8)>(loc, builder);
|
||||
else if (eleTy ==
|
||||
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(16)))
|
||||
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(16)))
|
||||
func =
|
||||
fir::runtime::getRuntimeFunc<ForcedDotProductInteger16>(loc, builder);
|
||||
else if (eleTy.isa<fir::LogicalType>())
|
||||
func =
|
||||
fir::runtime::getRuntimeFunc<mkRTKey(DotProductLogical)>(loc, builder);
|
||||
else
|
||||
fir::emitFatalError(loc, "invalid type in DotProduct lowering");
|
||||
fir::emitFatalError(loc, "invalid type in DOTPRODUCT");
|
||||
|
||||
auto fTy = func.getFunctionType();
|
||||
auto sourceFile = fir::factory::locationToFilename(builder, loc);
|
||||
|
@ -873,7 +867,9 @@ mlir::Value fir::runtime::genSum(fir::FirOpBuilder &builder, mlir::Location loc,
|
|||
auto eleTy = arrTy.cast<fir::SequenceType>().getEleTy();
|
||||
auto dim = builder.createIntegerConstant(loc, builder.getIndexType(), 0);
|
||||
|
||||
if (eleTy.isF32())
|
||||
if (eleTy.isF16() || eleTy.isBF16())
|
||||
TODO(loc, "half-precision SUM");
|
||||
else if (eleTy.isF32())
|
||||
func = fir::runtime::getRuntimeFunc<mkRTKey(SumReal4)>(loc, builder);
|
||||
else if (eleTy.isF64())
|
||||
func = fir::runtime::getRuntimeFunc<mkRTKey(SumReal8)>(loc, builder);
|
||||
|
@ -881,20 +877,15 @@ mlir::Value fir::runtime::genSum(fir::FirOpBuilder &builder, mlir::Location loc,
|
|||
func = fir::runtime::getRuntimeFunc<ForcedSumReal10>(loc, builder);
|
||||
else if (eleTy.isF128())
|
||||
func = fir::runtime::getRuntimeFunc<ForcedSumReal16>(loc, builder);
|
||||
else if (eleTy ==
|
||||
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(1)))
|
||||
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1)))
|
||||
func = fir::runtime::getRuntimeFunc<mkRTKey(SumInteger1)>(loc, builder);
|
||||
else if (eleTy ==
|
||||
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(2)))
|
||||
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(2)))
|
||||
func = fir::runtime::getRuntimeFunc<mkRTKey(SumInteger2)>(loc, builder);
|
||||
else if (eleTy ==
|
||||
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(4)))
|
||||
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(4)))
|
||||
func = fir::runtime::getRuntimeFunc<mkRTKey(SumInteger4)>(loc, builder);
|
||||
else if (eleTy ==
|
||||
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(8)))
|
||||
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(8)))
|
||||
func = fir::runtime::getRuntimeFunc<mkRTKey(SumInteger8)>(loc, builder);
|
||||
else if (eleTy ==
|
||||
builder.getIntegerType(builder.getKindMap().getIntegerBitsize(16)))
|
||||
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(16)))
|
||||
func = fir::runtime::getRuntimeFunc<ForcedSumInteger16>(loc, builder);
|
||||
else if (eleTy == fir::ComplexType::get(builder.getContext(), 4))
|
||||
func = fir::runtime::getRuntimeFunc<mkRTKey(CppSumComplex4)>(loc, builder);
|
||||
|
@ -904,8 +895,11 @@ mlir::Value fir::runtime::genSum(fir::FirOpBuilder &builder, mlir::Location loc,
|
|||
func = fir::runtime::getRuntimeFunc<ForcedSumComplex10>(loc, builder);
|
||||
else if (eleTy == fir::ComplexType::get(builder.getContext(), 16))
|
||||
func = fir::runtime::getRuntimeFunc<ForcedSumComplex16>(loc, builder);
|
||||
else if (eleTy == fir::ComplexType::get(builder.getContext(), 2) ||
|
||||
eleTy == fir::ComplexType::get(builder.getContext(), 3))
|
||||
TODO(loc, "half-precision SUM");
|
||||
else
|
||||
fir::emitFatalError(loc, "invalid type in Sum lowering");
|
||||
fir::emitFatalError(loc, "invalid type in SUM");
|
||||
|
||||
auto fTy = func.getFunctionType();
|
||||
auto sourceFile = fir::factory::locationToFilename(builder, loc);
|
||||
|
|
Loading…
Reference in New Issue