[mlir][Vector] Support 0-D vectors in `VectorPrintOpConversion`

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D114549
This commit is contained in:
Michal Terepeta 2021-11-25 20:10:02 +00:00 committed by Nicolas Vasilache
parent 0796869e4e
commit cc311a155a
3 changed files with 56 additions and 10 deletions

View File

@ -57,8 +57,7 @@ static Value insertOne(ConversionPatternRewriter &rewriter,
static Value extractOne(ConversionPatternRewriter &rewriter,
LLVMTypeConverter &typeConverter, Location loc,
Value val, Type llvmType, int64_t rank, int64_t pos) {
assert(rank > 0 && "0-D vector corner case should have been handled already");
if (rank == 1) {
if (rank <= 1) {
auto idxType = rewriter.getIndexType();
auto constant = rewriter.create<LLVM::ConstantOp>(
loc, typeConverter.convertType(idxType),
@ -987,7 +986,8 @@ public:
// Unroll vector into elementary print calls.
int64_t rank = vectorType ? vectorType.getRank() : 0;
emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
Type type = vectorType ? vectorType : eltType;
emitRanks(rewriter, printOp, adaptor.source(), type, printer, rank,
conversion);
emitCall(rewriter, printOp->getLoc(),
LLVM::lookupOrCreatePrintNewlineFn(
@ -1006,10 +1006,12 @@ private:
};
void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
Value value, VectorType vectorType, Operation *printer,
int64_t rank, PrintConversion conversion) const {
Value value, Type type, Operation *printer, int64_t rank,
PrintConversion conversion) const {
VectorType vectorType = type.dyn_cast<VectorType>();
Location loc = op->getLoc();
if (rank == 0) {
if (!vectorType) {
assert(rank == 0 && "The scalar case expects rank == 0");
switch (conversion) {
case PrintConversion::ZeroExt64:
value = rewriter.create<arith::ExtUIOp>(
@ -1030,12 +1032,29 @@ private:
LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
Operation *printComma =
LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());
if (rank <= 1) {
auto reducedType = vectorType.getElementType();
auto llvmType = typeConverter->convertType(reducedType);
int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0);
for (int64_t d = 0; d < dim; ++d) {
Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
llvmType, /*rank=*/0, /*pos=*/d);
emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0,
conversion);
if (d != dim - 1)
emitCall(rewriter, loc, printComma);
}
emitCall(
rewriter, loc,
LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
return;
}
int64_t dim = vectorType.getDimSize(0);
for (int64_t d = 0; d < dim; ++d) {
auto reducedType =
rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
auto llvmType = typeConverter->convertType(
rank > 1 ? reducedType : vectorType.getElementType());
auto reducedType = reducedVectorTypeFront(vectorType);
auto llvmType = typeConverter->convertType(reducedType);
Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
llvmType, rank, d);
emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,

View File

@ -832,6 +832,23 @@ func @vector_print_scalar_f64(%arg0: f64) {
// -----
func @vector_print_vector_0d(%arg0: vector<f32>) {
vector.print %arg0 : vector<f32>
return
}
// CHECK-LABEL: @vector_print_vector_0d(
// CHECK-SAME: %[[A:.*]]: vector<f32>)
// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<f32> to vector<1xf32>
// CHECK: llvm.call @printOpen() : () -> ()
// CHECK: %[[T1:.*]] = llvm.mlir.constant(0 : index) : i64
// CHECK: %[[T2:.*]] = llvm.extractelement %[[T0]][%[[T1]] : i64] : vector<1xf32>
// CHECK: llvm.call @printF32(%[[T2]]) : (f32) -> ()
// CHECK: llvm.call @printClose() : () -> ()
// CHECK: llvm.call @printNewline() : () -> ()
// CHECK: return
// -----
func @vector_print_vector(%arg0: vector<2x2xf32>) {
vector.print %arg0 : vector<2x2xf32>
return

View File

@ -15,10 +15,20 @@ func @insert_element_0d(%a: f32, %b: vector<f32>) -> (vector<f32>) {
return %1: vector<f32>
}
func @print_vector_0d(%a: vector<f32>) {
// CHECK: ( 42 )
vector.print %a: vector<f32>
return
}
func @entry() {
%0 = arith.constant 42.0 : f32
%1 = arith.constant dense<0.0> : vector<f32>
%2 = call @insert_element_0d(%0, %1) : (f32, vector<f32>) -> (vector<f32>)
call @extract_element_0d(%2) : (vector<f32>) -> ()
%3 = arith.constant dense<42.0> : vector<f32>
call @print_vector_0d(%3) : (vector<f32>) -> ()
return
}