[mlir] [VectorOps] generalize printing support for integers

This generalizes printing beyond just i1,i32,i64 and also accounts
for signed and unsigned interpretation in the output.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D88290
This commit is contained in:
Aart Bik 2020-09-25 03:32:05 -07:00
parent f330d9f163
commit b8880f5f97
4 changed files with 242 additions and 24 deletions

View File

@ -0,0 +1,76 @@
// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
//
// Test various signless, signed, unsigned integer types.
//
func @entry() {
%0 = std.constant dense<[true, false, -1, 0, 1]> : vector<5xi1>
vector.print %0 : vector<5xi1>
// CHECK: ( 1, 0, 1, 0, 1 )
%1 = std.constant dense<[true, false, -1, 0]> : vector<4xsi1>
vector.print %1 : vector<4xsi1>
// CHECK: ( 1, 0, 1, 0 )
%2 = std.constant dense<[true, false, 0, 1]> : vector<4xui1>
vector.print %2 : vector<4xui1>
// CHECK: ( 1, 0, 0, 1 )
%3 = std.constant dense<[-128, -127, -1, 0, 1, 127, 128, 254, 255]> : vector<9xi8>
vector.print %3 : vector<9xi8>
// CHECK: ( -128, -127, -1, 0, 1, 127, -128, -2, -1 )
%4 = std.constant dense<[-128, -127, -1, 0, 1, 127]> : vector<6xsi8>
vector.print %4 : vector<6xsi8>
// CHECK: ( -128, -127, -1, 0, 1, 127 )
%5 = std.constant dense<[0, 1, 127, 128, 254, 255]> : vector<6xui8>
vector.print %5 : vector<6xui8>
// CHECK: ( 0, 1, 127, 128, 254, 255 )
%6 = std.constant dense<[-32768, -32767, -1, 0, 1, 32767, 32768, 65534, 65535]> : vector<9xi16>
vector.print %6 : vector<9xi16>
// CHECK: ( -32768, -32767, -1, 0, 1, 32767, -32768, -2, -1 )
%7 = std.constant dense<[-32768, -32767, -1, 0, 1, 32767]> : vector<6xsi16>
vector.print %7 : vector<6xsi16>
// CHECK: ( -32768, -32767, -1, 0, 1, 32767 )
%8 = std.constant dense<[0, 1, 32767, 32768, 65534, 65535]> : vector<6xui16>
vector.print %8 : vector<6xui16>
// CHECK: ( 0, 1, 32767, 32768, 65534, 65535 )
%9 = std.constant dense<[-2147483648, -2147483647, -1, 0, 1,
2147483647, 2147483648, 4294967294, 4294967295]> : vector<9xi32>
vector.print %9 : vector<9xi32>
// CHECK: ( -2147483648, -2147483647, -1, 0, 1, 2147483647, -2147483648, -2, -1 )
%10 = std.constant dense<[-2147483648, -2147483647, -1, 0, 1, 2147483647]> : vector<6xsi32>
vector.print %10 : vector<6xsi32>
// CHECK: ( -2147483648, -2147483647, -1, 0, 1, 2147483647 )
%11 = std.constant dense<[0, 1, 2147483647, 2147483648, 4294967294, 4294967295]> : vector<6xui32>
vector.print %11 : vector<6xui32>
// CHECK: ( 0, 1, 2147483647, 2147483648, 4294967294, 4294967295 )
%12 = std.constant dense<[-9223372036854775808, -9223372036854775807, -1, 0, 1,
9223372036854775807, 9223372036854775808,
18446744073709551614, 18446744073709551615]> : vector<9xi64>
vector.print %12 : vector<9xi64>
// CHECK: ( -9223372036854775808, -9223372036854775807, -1, 0, 1, 9223372036854775807, -9223372036854775808, -2, -1 )
%13 = std.constant dense<[-9223372036854775808, -9223372036854775807, -1, 0, 1,
9223372036854775807]> : vector<6xsi64>
vector.print %13 : vector<6xsi64>
// CHECK: ( -9223372036854775808, -9223372036854775807, -1, 0, 1, 9223372036854775807 )
%14 = std.constant dense<[0, 1, 9223372036854775807, 9223372036854775808,
18446744073709551614, 18446744073709551615]> : vector<6xui64>
vector.print %14 : vector<6xui64>
// CHECK: ( 0, 1, 9223372036854775807, 9223372036854775808, 18446744073709551614, 18446744073709551615 )
return
}

View File

@ -1319,44 +1319,96 @@ public:
if (typeConverter.convertType(printType) == nullptr)
return failure();
// Make sure element type has runtime support (currently just Float/Double).
// Make sure element type has runtime support.
PrintConversion conversion = PrintConversion::None;
VectorType vectorType = printType.dyn_cast<VectorType>();
Type eltType = vectorType ? vectorType.getElementType() : printType;
int64_t rank = vectorType ? vectorType.getRank() : 0;
Operation *printer;
if (eltType.isSignlessInteger(1) || eltType.isSignlessInteger(32))
printer = getPrintI32(op);
else if (eltType.isSignlessInteger(64))
printer = getPrintI64(op);
else if (eltType.isF32())
if (eltType.isF32()) {
printer = getPrintFloat(op);
else if (eltType.isF64())
} else if (eltType.isF64()) {
printer = getPrintDouble(op);
else
} else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
// Integers need a zero or sign extension on the operand
// (depending on the source type) as well as a signed or
// unsigned print method. Up to 64-bit is supported.
unsigned width = intTy.getWidth();
if (intTy.isUnsigned()) {
if (width <= 32) {
if (width < 32)
conversion = PrintConversion::ZeroExt32;
printer = getPrintU32(op);
} else if (width <= 64) {
if (width < 64)
conversion = PrintConversion::ZeroExt64;
printer = getPrintU64(op);
} else {
return failure();
}
} else {
assert(intTy.isSignless() || intTy.isSigned());
if (width <= 32) {
// Note that we *always* zero extend booleans (1-bit integers),
// so that true/false is printed as 1/0 rather than -1/0.
if (width == 1)
conversion = PrintConversion::ZeroExt32;
else if (width < 32)
conversion = PrintConversion::SignExt32;
printer = getPrintI32(op);
} else if (width <= 64) {
if (width < 64)
conversion = PrintConversion::SignExt64;
printer = getPrintI64(op);
} else {
return failure();
}
}
} else {
return failure();
}
// Unroll vector into elementary print calls.
emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank);
int64_t rank = vectorType ? vectorType.getRank() : 0;
emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank,
conversion);
emitCall(rewriter, op->getLoc(), getPrintNewline(op));
rewriter.eraseOp(op);
return success();
}
private:
enum class PrintConversion {
None,
ZeroExt32,
SignExt32,
ZeroExt64,
SignExt64
};
void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
Value value, VectorType vectorType, Operation *printer,
int64_t rank) const {
int64_t rank, PrintConversion conversion) const {
Location loc = op->getLoc();
if (rank == 0) {
if (value.getType() == LLVM::LLVMType::getInt1Ty(rewriter.getContext())) {
// Convert i1 (bool) to i32 so we can use the print_i32 method.
// This avoids the need for a print_i1 method with an unclear ABI.
auto i32Type = LLVM::LLVMType::getInt32Ty(rewriter.getContext());
auto trueVal = rewriter.create<ConstantOp>(
loc, i32Type, rewriter.getI32IntegerAttr(1));
auto falseVal = rewriter.create<ConstantOp>(
loc, i32Type, rewriter.getI32IntegerAttr(0));
value = rewriter.create<SelectOp>(loc, value, trueVal, falseVal);
switch (conversion) {
case PrintConversion::ZeroExt32:
value = rewriter.create<ZeroExtendIOp>(
loc, value, LLVM::LLVMType::getInt32Ty(rewriter.getContext()));
break;
case PrintConversion::SignExt32:
value = rewriter.create<SignExtendIOp>(
loc, value, LLVM::LLVMType::getInt32Ty(rewriter.getContext()));
break;
case PrintConversion::ZeroExt64:
value = rewriter.create<ZeroExtendIOp>(
loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext()));
break;
case PrintConversion::SignExt64:
value = rewriter.create<SignExtendIOp>(
loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext()));
break;
case PrintConversion::None:
break;
}
emitCall(rewriter, loc, printer, value);
return;
@ -1372,7 +1424,8 @@ private:
rank > 1 ? reducedType : vectorType.getElementType());
Value nestedVal =
extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d);
emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1);
emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
conversion);
if (d != dim - 1)
emitCall(rewriter, loc, printComma);
}
@ -1410,6 +1463,14 @@ private:
return getPrint(op, "print_i64",
LLVM::LLVMType::getInt64Ty(op->getContext()));
}
Operation *getPrintU32(Operation *op) const {
return getPrint(op, "printU32",
LLVM::LLVMType::getInt32Ty(op->getContext()));
}
Operation *getPrintU64(Operation *op) const {
return getPrint(op, "printU64",
LLVM::LLVMType::getInt64Ty(op->getContext()));
}
Operation *getPrintFloat(Operation *op) const {
return getPrint(op, "print_f32",
LLVM::LLVMType::getFloatTy(op->getContext()));

View File

@ -25,6 +25,8 @@
// details of our vectors. Also useful for direct LLVM IR output.
extern "C" void print_i32(int32_t i) { fprintf(stdout, "%" PRId32, i); }
extern "C" void print_i64(int64_t l) { fprintf(stdout, "%" PRId64, l); }
extern "C" void printU32(uint32_t i) { fprintf(stdout, "%" PRIu32, i); }
extern "C" void printU64(uint64_t l) { fprintf(stdout, "%" PRIu64, l); }
extern "C" void print_f32(float f) { fprintf(stdout, "%g", f); }
extern "C" void print_f64(double d) { fprintf(stdout, "%lg", d); }
extern "C" void print_open() { fputs("( ", stdout); }

View File

@ -433,14 +433,45 @@ func @vector_print_scalar_i1(%arg0: i1) {
vector.print %arg0 : i1
return
}
//
// Type "boolean" always uses zero extension.
//
// CHECK-LABEL: llvm.func @vector_print_scalar_i1(
// CHECK-SAME: %[[A:.*]]: !llvm.i1)
// CHECK: %[[T:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
// CHECK: %[[F:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
// CHECK: %[[S:.*]] = llvm.select %[[A]], %[[T]], %[[F]] : !llvm.i1, !llvm.i32
// CHECK: %[[S:.*]] = llvm.zext %[[A]] : !llvm.i1 to !llvm.i32
// CHECK: llvm.call @print_i32(%[[S]]) : (!llvm.i32) -> ()
// CHECK: llvm.call @print_newline() : () -> ()
func @vector_print_scalar_i4(%arg0: i4) {
vector.print %arg0 : i4
return
}
// CHECK-LABEL: llvm.func @vector_print_scalar_i4(
// CHECK-SAME: %[[A:.*]]: !llvm.i4)
// CHECK: %[[S:.*]] = llvm.sext %[[A]] : !llvm.i4 to !llvm.i32
// CHECK: llvm.call @print_i32(%[[S]]) : (!llvm.i32) -> ()
// CHECK: llvm.call @print_newline() : () -> ()
func @vector_print_scalar_si4(%arg0: si4) {
vector.print %arg0 : si4
return
}
// CHECK-LABEL: llvm.func @vector_print_scalar_si4(
// CHECK-SAME: %[[A:.*]]: !llvm.i4)
// CHECK: %[[S:.*]] = llvm.sext %[[A]] : !llvm.i4 to !llvm.i32
// CHECK: llvm.call @print_i32(%[[S]]) : (!llvm.i32) -> ()
// CHECK: llvm.call @print_newline() : () -> ()
func @vector_print_scalar_ui4(%arg0: ui4) {
vector.print %arg0 : ui4
return
}
// CHECK-LABEL: llvm.func @vector_print_scalar_ui4(
// CHECK-SAME: %[[A:.*]]: !llvm.i4)
// CHECK: %[[S:.*]] = llvm.zext %[[A]] : !llvm.i4 to !llvm.i32
// CHECK: llvm.call @printU32(%[[S]]) : (!llvm.i32) -> ()
// CHECK: llvm.call @print_newline() : () -> ()
func @vector_print_scalar_i32(%arg0: i32) {
vector.print %arg0 : i32
return
@ -450,6 +481,45 @@ func @vector_print_scalar_i32(%arg0: i32) {
// CHECK: llvm.call @print_i32(%[[A]]) : (!llvm.i32) -> ()
// CHECK: llvm.call @print_newline() : () -> ()
func @vector_print_scalar_ui32(%arg0: ui32) {
vector.print %arg0 : ui32
return
}
// CHECK-LABEL: llvm.func @vector_print_scalar_ui32(
// CHECK-SAME: %[[A:.*]]: !llvm.i32)
// CHECK: llvm.call @printU32(%[[A]]) : (!llvm.i32) -> ()
// CHECK: llvm.call @print_newline() : () -> ()
func @vector_print_scalar_i40(%arg0: i40) {
vector.print %arg0 : i40
return
}
// CHECK-LABEL: llvm.func @vector_print_scalar_i40(
// CHECK-SAME: %[[A:.*]]: !llvm.i40)
// CHECK: %[[S:.*]] = llvm.sext %[[A]] : !llvm.i40 to !llvm.i64
// CHECK: llvm.call @print_i64(%[[S]]) : (!llvm.i64) -> ()
// CHECK: llvm.call @print_newline() : () -> ()
func @vector_print_scalar_si40(%arg0: si40) {
vector.print %arg0 : si40
return
}
// CHECK-LABEL: llvm.func @vector_print_scalar_si40(
// CHECK-SAME: %[[A:.*]]: !llvm.i40)
// CHECK: %[[S:.*]] = llvm.sext %[[A]] : !llvm.i40 to !llvm.i64
// CHECK: llvm.call @print_i64(%[[S]]) : (!llvm.i64) -> ()
// CHECK: llvm.call @print_newline() : () -> ()
func @vector_print_scalar_ui40(%arg0: ui40) {
vector.print %arg0 : ui40
return
}
// CHECK-LABEL: llvm.func @vector_print_scalar_ui40(
// CHECK-SAME: %[[A:.*]]: !llvm.i40)
// CHECK: %[[S:.*]] = llvm.zext %[[A]] : !llvm.i40 to !llvm.i64
// CHECK: llvm.call @printU64(%[[S]]) : (!llvm.i64) -> ()
// CHECK: llvm.call @print_newline() : () -> ()
func @vector_print_scalar_i64(%arg0: i64) {
vector.print %arg0 : i64
return
@ -459,6 +529,15 @@ func @vector_print_scalar_i64(%arg0: i64) {
// CHECK: llvm.call @print_i64(%[[A]]) : (!llvm.i64) -> ()
// CHECK: llvm.call @print_newline() : () -> ()
func @vector_print_scalar_ui64(%arg0: ui64) {
vector.print %arg0 : ui64
return
}
// CHECK-LABEL: llvm.func @vector_print_scalar_ui64(
// CHECK-SAME: %[[A:.*]]: !llvm.i64)
// CHECK: llvm.call @printU64(%[[A]]) : (!llvm.i64) -> ()
// CHECK: llvm.call @print_newline() : () -> ()
func @vector_print_scalar_f32(%arg0: f32) {
vector.print %arg0 : f32
return