forked from OSchip/llvm-project
[VectorOps] Add vector.print definition, with lowering support
Examples: vector.print %f : f32 vector.print %x : vector<4xf32> vector.print %y : vector<3x4xf32> vector.print %z : vector<2x3x4xf32> LLVM lowering replaces these with fully unrolled calls into a small runtime support library that provides some basic printing operations (single value, opening closing bracket, comma, newline). PiperOrigin-RevId: 286230325
This commit is contained in:
parent
c169852fc5
commit
d9b500d3bb
|
@ -987,4 +987,36 @@ def Vector_TupleGetOp :
|
|||
}];
|
||||
}
|
||||
|
||||
def Vector_PrintOp :
|
||||
Vector_Op<"print", []>, Arguments<(ins AnyType:$source)> {
|
||||
let summary = "print operation (for testing and debugging)";
|
||||
let description = [{
|
||||
Prints the source vector (or scalar) to stdout in human readable
|
||||
format (for testing and debugging). No return value.
|
||||
|
||||
Examples:
|
||||
```
|
||||
%0 = constant 0.0 : f32
|
||||
%1 = vector.broadcast %0 : f32 to vector<4xf32>
|
||||
vector.print %1 : vector<4xf32>
|
||||
|
||||
when lowered to LLVM, the vector print is unrolled into
|
||||
elementary printing method calls that at runtime will yield
|
||||
|
||||
( 0.0, 0.0, 0.0, 0.0 )
|
||||
|
||||
on stdout when linked with a small runtime support library,
|
||||
which only needs to provide a few printing methods (single
|
||||
value for all data types, opening/closing bracket, comma,
|
||||
newline).
|
||||
```
|
||||
}];
|
||||
let verifier = ?;
|
||||
let extraClassDeclaration = [{
|
||||
Type getPrintType() {
|
||||
return source()->getType();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // VECTOR_OPS
|
||||
|
|
|
@ -612,14 +612,136 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
class VectorPrintOpConversion : public LLVMOpLowering {
|
||||
public:
|
||||
explicit VectorPrintOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: LLVMOpLowering(vector::PrintOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
|
||||
// Proof-of-concept lowering implementation that relies on a small
|
||||
// runtime support library, which only needs to provide a few
|
||||
// printing methods (single value for all data types, opening/closing
|
||||
// bracket, comma, newline). The lowering fully unrolls a vector
|
||||
// in terms of these elementary printing operations. The advantage
|
||||
// of this approach is that the library can remain unaware of all
|
||||
// low-level implementation details of vectors while still supporting
|
||||
// output of any shaped and dimensioned vector. Due to full unrolling,
|
||||
// this approach is less suited for very large vectors though.
|
||||
//
|
||||
// TODO(ajcbik): rely solely on libc in future? something else?
|
||||
//
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto printOp = cast<vector::PrintOp>(op);
|
||||
auto adaptor = vector::PrintOpOperandAdaptor(operands);
|
||||
Type printType = printOp.getPrintType();
|
||||
|
||||
if (lowering.convertType(printType) == nullptr)
|
||||
return matchFailure();
|
||||
|
||||
// Make sure element type has runtime support (currently just Float/Double).
|
||||
VectorType vectorType = printType.dyn_cast<VectorType>();
|
||||
Type eltType = vectorType ? vectorType.getElementType() : printType;
|
||||
int64_t rank = vectorType ? vectorType.getRank() : 0;
|
||||
Operation *printer;
|
||||
if (eltType.isF32())
|
||||
printer = getPrintFloat(op);
|
||||
else if (eltType.isF64())
|
||||
printer = getPrintDouble(op);
|
||||
else
|
||||
return matchFailure();
|
||||
|
||||
// Unroll vector into elementary print calls.
|
||||
emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank);
|
||||
emitCall(rewriter, op->getLoc(), getPrintNewline(op));
|
||||
rewriter.eraseOp(op);
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
private:
|
||||
void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
|
||||
Value *value, VectorType vectorType, Operation *printer,
|
||||
int64_t rank) const {
|
||||
Location loc = op->getLoc();
|
||||
if (rank == 0) {
|
||||
emitCall(rewriter, loc, printer, value);
|
||||
return;
|
||||
}
|
||||
|
||||
emitCall(rewriter, loc, getPrintOpen(op));
|
||||
Operation *printComma = getPrintComma(op);
|
||||
int64_t dim = vectorType.getDimSize(0);
|
||||
for (int64_t d = 0; d < dim; ++d) {
|
||||
auto reducedType =
|
||||
rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
|
||||
auto llvmType = lowering.convertType(
|
||||
rank > 1 ? reducedType : vectorType.getElementType());
|
||||
Value *nestedVal =
|
||||
extractOne(rewriter, lowering, loc, value, llvmType, rank, d);
|
||||
emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1);
|
||||
if (d != dim - 1)
|
||||
emitCall(rewriter, loc, printComma);
|
||||
}
|
||||
emitCall(rewriter, loc, getPrintClose(op));
|
||||
}
|
||||
|
||||
// Helper to emit a call.
|
||||
static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Operation *ref, ValueRange params = ValueRange()) {
|
||||
rewriter.create<LLVM::CallOp>(loc, ArrayRef<Type>{},
|
||||
rewriter.getSymbolRefAttr(ref), params);
|
||||
}
|
||||
|
||||
// Helper for printer method declaration (first hit) and lookup.
|
||||
static Operation *getPrint(Operation *op, LLVM::LLVMDialect *dialect,
|
||||
StringRef name, ArrayRef<LLVM::LLVMType> params) {
|
||||
auto module = op->getParentOfType<ModuleOp>();
|
||||
auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
|
||||
if (func)
|
||||
return func;
|
||||
OpBuilder moduleBuilder(module.getBodyRegion());
|
||||
return moduleBuilder.create<LLVM::LLVMFuncOp>(
|
||||
op->getLoc(), name,
|
||||
LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(dialect),
|
||||
params, /*isVarArg=*/false));
|
||||
}
|
||||
|
||||
// Helpers for method names.
|
||||
Operation *getPrintFloat(Operation *op) const {
|
||||
LLVM::LLVMDialect *dialect = lowering.getDialect();
|
||||
return getPrint(op, dialect, "print_f32",
|
||||
LLVM::LLVMType::getFloatTy(dialect));
|
||||
}
|
||||
Operation *getPrintDouble(Operation *op) const {
|
||||
LLVM::LLVMDialect *dialect = lowering.getDialect();
|
||||
return getPrint(op, dialect, "print_f64",
|
||||
LLVM::LLVMType::getDoubleTy(dialect));
|
||||
}
|
||||
Operation *getPrintOpen(Operation *op) const {
|
||||
return getPrint(op, lowering.getDialect(), "print_open", {});
|
||||
}
|
||||
Operation *getPrintClose(Operation *op) const {
|
||||
return getPrint(op, lowering.getDialect(), "print_close", {});
|
||||
}
|
||||
Operation *getPrintComma(Operation *op) const {
|
||||
return getPrint(op, lowering.getDialect(), "print_comma", {});
|
||||
}
|
||||
Operation *getPrintNewline(Operation *op) const {
|
||||
return getPrint(op, lowering.getDialect(), "print_newline", {});
|
||||
}
|
||||
};
|
||||
|
||||
/// Populate the given list with patterns that convert from Vector to LLVM.
|
||||
void mlir::populateVectorToLLVMConversionPatterns(
|
||||
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
|
||||
patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion,
|
||||
VectorExtractElementOpConversion, VectorExtractOpConversion,
|
||||
VectorInsertElementOpConversion, VectorInsertOpConversion,
|
||||
VectorOuterProductOpConversion, VectorTypeCastOpConversion>(
|
||||
converter.getDialect()->getContext(), converter);
|
||||
VectorOuterProductOpConversion, VectorTypeCastOpConversion,
|
||||
VectorPrintOpConversion>(converter.getDialect()->getContext(),
|
||||
converter);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
|
|
@ -1587,6 +1587,23 @@ static LogicalResult verify(CreateMaskOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PrintOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ParseResult parsePrintOp(OpAsmParser &parser, OperationState &result) {
|
||||
OpAsmParser::OperandType source;
|
||||
Type sourceType;
|
||||
return failure(parser.parseOperand(source) ||
|
||||
parser.parseColonType(sourceType) ||
|
||||
parser.resolveOperand(source, sourceType, result.operands));
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, PrintOp op) {
|
||||
p << op.getOperationName() << ' ' << *op.source() << " : "
|
||||
<< op.getPrintType();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
|
||||
|
|
|
@ -385,3 +385,41 @@ func @vector_type_cast(%arg0: memref<8x8x8xf32>) -> memref<vector<8x8x8xf32>> {
|
|||
// CHECK: llvm.insertvalue %[[alignedBit]], {{.*}}[1] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }">
|
||||
// CHECK: llvm.mlir.constant(0 : index
|
||||
// CHECK: llvm.insertvalue {{.*}}[2] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }">
|
||||
|
||||
func @vector_print_scalar(%arg0: f32) {
|
||||
vector.print %arg0 : f32
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: vector_print_scalar(%arg0: !llvm.float)
|
||||
// CHECK: llvm.call @print_f32(%arg0) : (!llvm.float) -> ()
|
||||
// CHECK: llvm.call @print_newline() : () -> ()
|
||||
|
||||
func @vector_print_vector(%arg0: vector<2x2xf32>) {
|
||||
vector.print %arg0 : vector<2x2xf32>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: vector_print_vector(%arg0: !llvm<"[2 x <2 x float>]">)
|
||||
// CHECK: llvm.call @print_open() : () -> ()
|
||||
// CHECK: %[[x0:.*]] = llvm.extractvalue %arg0[0] : !llvm<"[2 x <2 x float>]">
|
||||
// CHECK: llvm.call @print_open() : () -> ()
|
||||
// CHECK: %[[x1:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
|
||||
// CHECK: %[[x2:.*]] = llvm.extractelement %[[x0]][%[[x1]] : !llvm.i64] : !llvm<"<2 x float>">
|
||||
// CHECK: llvm.call @print_f32(%[[x2]]) : (!llvm.float) -> ()
|
||||
// CHECK: llvm.call @print_comma() : () -> ()
|
||||
// CHECK: %[[x3:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
|
||||
// CHECK: %[[x4:.*]] = llvm.extractelement %[[x0]][%[[x3]] : !llvm.i64] : !llvm<"<2 x float>">
|
||||
// CHECK: llvm.call @print_f32(%[[x4]]) : (!llvm.float) -> ()
|
||||
// CHECK: llvm.call @print_close() : () -> ()
|
||||
// CHECK: llvm.call @print_comma() : () -> ()
|
||||
// CHECK: %[[x5:.*]] = llvm.extractvalue %arg0[1] : !llvm<"[2 x <2 x float>]">
|
||||
// CHECK: llvm.call @print_open() : () -> ()
|
||||
// CHECK: %[[x6:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
|
||||
// CHECK: %[[x7:.*]] = llvm.extractelement %[[x5]][%[[x6]] : !llvm.i64] : !llvm<"<2 x float>">
|
||||
// CHECK: llvm.call @print_f32(%[[x7]]) : (!llvm.float) -> ()
|
||||
// CHECK: llvm.call @print_comma() : () -> ()
|
||||
// CHECK: %[[x8:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
|
||||
// CHECK: %[[x9:.*]] = llvm.extractelement %[[x5]][%[[x8]] : !llvm.i64] : !llvm<"<2 x float>">
|
||||
// CHECK: llvm.call @print_f32(%[[x9]]) : (!llvm.float) -> ()
|
||||
// CHECK: llvm.call @print_close() : () -> ()
|
||||
// CHECK: llvm.call @print_close() : () -> ()
|
||||
// CHECK: llvm.call @print_newline() : () -> ()
|
||||
|
|
|
@ -818,3 +818,11 @@ func @insert_slices_invalid_tuple_element_type(%arg0 : tuple<vector<2x2xf32>, ve
|
|||
: tuple<vector<2x2xf32>, vector<4x2xf32>> into vector<4x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @print_no_result(%arg0 : f32) -> i32 {
|
||||
// expected-error@+1 {{cannot name an operation with no results}}
|
||||
%0 = vector.print %arg0 : f32
|
||||
return %0
|
||||
}
|
||||
|
|
|
@ -198,3 +198,10 @@ func @insert_slices(%arg0 : tuple<vector<2x2xf32>, vector<2x2xf32>>)
|
|||
: tuple<vector<2x2xf32>, vector<2x2xf32>> into vector<4x2xf32>
|
||||
return %0 : vector<4x2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @vector_print
|
||||
func @vector_print(%arg0: vector<8x4xf32>) {
|
||||
// CHECK: vector.print %{{.*}} : vector<8x4xf32>
|
||||
vector.print %arg0 : vector<8x4xf32>
|
||||
return
|
||||
}
|
||||
|
|
|
@ -285,4 +285,12 @@ print_memref_4d_f32(StridedMemRefType<float, 4> *M);
|
|||
extern "C" MLIR_RUNNER_UTILS_EXPORT void
|
||||
print_memref_vector_4x4xf32(StridedMemRefType<Vector2D<4, 4, float>, 2> *M);
|
||||
|
||||
// Small runtime support "lib" for vector.print lowering.
|
||||
extern "C" MLIR_RUNNER_UTILS_EXPORT void print_f32(float f);
|
||||
extern "C" MLIR_RUNNER_UTILS_EXPORT void print_f64(double d);
|
||||
extern "C" MLIR_RUNNER_UTILS_EXPORT void print_open();
|
||||
extern "C" MLIR_RUNNER_UTILS_EXPORT void print_close();
|
||||
extern "C" MLIR_RUNNER_UTILS_EXPORT void print_comma();
|
||||
extern "C" MLIR_RUNNER_UTILS_EXPORT void print_newline();
|
||||
|
||||
#endif // MLIR_CPU_RUNNER_MLIRUTILS_H_
|
||||
|
|
|
@ -63,3 +63,14 @@ extern "C" void print_memref_3d_f32(StridedMemRefType<float, 3> *M) {
|
|||
extern "C" void print_memref_4d_f32(StridedMemRefType<float, 4> *M) {
|
||||
impl::printMemRef(*M);
|
||||
}
|
||||
|
||||
// Small runtime support "lib" for vector.print lowering.
|
||||
// By providing elementary printing methods only, this
|
||||
// library can remain fully unaware of low-level implementation
|
||||
// details of our vectors.
|
||||
extern "C" void print_f32(float f) { std::cout << f; }
|
||||
extern "C" void print_f64(double d) { std::cout << d; }
|
||||
extern "C" void print_open() { std::cout << "( "; }
|
||||
extern "C" void print_close() { std::cout << " )"; }
|
||||
extern "C" void print_comma() { std::cout << ", "; }
|
||||
extern "C" void print_newline() { std::cout << "\n"; }
|
||||
|
|
Loading…
Reference in New Issue