forked from OSchip/llvm-project
[mlir][Vector] Support 0-D vectors in `VectorPrintOpConversion`
Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D114549
This commit is contained in:
parent
0796869e4e
commit
cc311a155a
|
@ -57,8 +57,7 @@ static Value insertOne(ConversionPatternRewriter &rewriter,
|
||||||
static Value extractOne(ConversionPatternRewriter &rewriter,
|
static Value extractOne(ConversionPatternRewriter &rewriter,
|
||||||
LLVMTypeConverter &typeConverter, Location loc,
|
LLVMTypeConverter &typeConverter, Location loc,
|
||||||
Value val, Type llvmType, int64_t rank, int64_t pos) {
|
Value val, Type llvmType, int64_t rank, int64_t pos) {
|
||||||
assert(rank > 0 && "0-D vector corner case should have been handled already");
|
if (rank <= 1) {
|
||||||
if (rank == 1) {
|
|
||||||
auto idxType = rewriter.getIndexType();
|
auto idxType = rewriter.getIndexType();
|
||||||
auto constant = rewriter.create<LLVM::ConstantOp>(
|
auto constant = rewriter.create<LLVM::ConstantOp>(
|
||||||
loc, typeConverter.convertType(idxType),
|
loc, typeConverter.convertType(idxType),
|
||||||
|
@ -987,7 +986,8 @@ public:
|
||||||
|
|
||||||
// Unroll vector into elementary print calls.
|
// Unroll vector into elementary print calls.
|
||||||
int64_t rank = vectorType ? vectorType.getRank() : 0;
|
int64_t rank = vectorType ? vectorType.getRank() : 0;
|
||||||
emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
|
Type type = vectorType ? vectorType : eltType;
|
||||||
|
emitRanks(rewriter, printOp, adaptor.source(), type, printer, rank,
|
||||||
conversion);
|
conversion);
|
||||||
emitCall(rewriter, printOp->getLoc(),
|
emitCall(rewriter, printOp->getLoc(),
|
||||||
LLVM::lookupOrCreatePrintNewlineFn(
|
LLVM::lookupOrCreatePrintNewlineFn(
|
||||||
|
@ -1006,10 +1006,12 @@ private:
|
||||||
};
|
};
|
||||||
|
|
||||||
void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
|
void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
|
||||||
Value value, VectorType vectorType, Operation *printer,
|
Value value, Type type, Operation *printer, int64_t rank,
|
||||||
int64_t rank, PrintConversion conversion) const {
|
PrintConversion conversion) const {
|
||||||
|
VectorType vectorType = type.dyn_cast<VectorType>();
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
if (rank == 0) {
|
if (!vectorType) {
|
||||||
|
assert(rank == 0 && "The scalar case expects rank == 0");
|
||||||
switch (conversion) {
|
switch (conversion) {
|
||||||
case PrintConversion::ZeroExt64:
|
case PrintConversion::ZeroExt64:
|
||||||
value = rewriter.create<arith::ExtUIOp>(
|
value = rewriter.create<arith::ExtUIOp>(
|
||||||
|
@ -1030,12 +1032,29 @@ private:
|
||||||
LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
|
LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
|
||||||
Operation *printComma =
|
Operation *printComma =
|
||||||
LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());
|
LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());
|
||||||
|
|
||||||
|
if (rank <= 1) {
|
||||||
|
auto reducedType = vectorType.getElementType();
|
||||||
|
auto llvmType = typeConverter->convertType(reducedType);
|
||||||
|
int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0);
|
||||||
|
for (int64_t d = 0; d < dim; ++d) {
|
||||||
|
Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
|
||||||
|
llvmType, /*rank=*/0, /*pos=*/d);
|
||||||
|
emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0,
|
||||||
|
conversion);
|
||||||
|
if (d != dim - 1)
|
||||||
|
emitCall(rewriter, loc, printComma);
|
||||||
|
}
|
||||||
|
emitCall(
|
||||||
|
rewriter, loc,
|
||||||
|
LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
int64_t dim = vectorType.getDimSize(0);
|
int64_t dim = vectorType.getDimSize(0);
|
||||||
for (int64_t d = 0; d < dim; ++d) {
|
for (int64_t d = 0; d < dim; ++d) {
|
||||||
auto reducedType =
|
auto reducedType = reducedVectorTypeFront(vectorType);
|
||||||
rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
|
auto llvmType = typeConverter->convertType(reducedType);
|
||||||
auto llvmType = typeConverter->convertType(
|
|
||||||
rank > 1 ? reducedType : vectorType.getElementType());
|
|
||||||
Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
|
Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
|
||||||
llvmType, rank, d);
|
llvmType, rank, d);
|
||||||
emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
|
emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
|
||||||
|
|
|
@ -832,6 +832,23 @@ func @vector_print_scalar_f64(%arg0: f64) {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
func @vector_print_vector_0d(%arg0: vector<f32>) {
|
||||||
|
vector.print %arg0 : vector<f32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK-LABEL: @vector_print_vector_0d(
|
||||||
|
// CHECK-SAME: %[[A:.*]]: vector<f32>)
|
||||||
|
// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<f32> to vector<1xf32>
|
||||||
|
// CHECK: llvm.call @printOpen() : () -> ()
|
||||||
|
// CHECK: %[[T1:.*]] = llvm.mlir.constant(0 : index) : i64
|
||||||
|
// CHECK: %[[T2:.*]] = llvm.extractelement %[[T0]][%[[T1]] : i64] : vector<1xf32>
|
||||||
|
// CHECK: llvm.call @printF32(%[[T2]]) : (f32) -> ()
|
||||||
|
// CHECK: llvm.call @printClose() : () -> ()
|
||||||
|
// CHECK: llvm.call @printNewline() : () -> ()
|
||||||
|
// CHECK: return
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
func @vector_print_vector(%arg0: vector<2x2xf32>) {
|
func @vector_print_vector(%arg0: vector<2x2xf32>) {
|
||||||
vector.print %arg0 : vector<2x2xf32>
|
vector.print %arg0 : vector<2x2xf32>
|
||||||
return
|
return
|
||||||
|
|
|
@ -15,10 +15,20 @@ func @insert_element_0d(%a: f32, %b: vector<f32>) -> (vector<f32>) {
|
||||||
return %1: vector<f32>
|
return %1: vector<f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func @print_vector_0d(%a: vector<f32>) {
|
||||||
|
// CHECK: ( 42 )
|
||||||
|
vector.print %a: vector<f32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func @entry() {
|
func @entry() {
|
||||||
%0 = arith.constant 42.0 : f32
|
%0 = arith.constant 42.0 : f32
|
||||||
%1 = arith.constant dense<0.0> : vector<f32>
|
%1 = arith.constant dense<0.0> : vector<f32>
|
||||||
%2 = call @insert_element_0d(%0, %1) : (f32, vector<f32>) -> (vector<f32>)
|
%2 = call @insert_element_0d(%0, %1) : (f32, vector<f32>) -> (vector<f32>)
|
||||||
call @extract_element_0d(%2) : (vector<f32>) -> ()
|
call @extract_element_0d(%2) : (vector<f32>) -> ()
|
||||||
|
|
||||||
|
%3 = arith.constant dense<42.0> : vector<f32>
|
||||||
|
call @print_vector_0d(%3) : (vector<f32>) -> ()
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue