From 0d65000e11777b8d2d6aa9f135753209593f2f00 Mon Sep 17 00:00:00 2001 From: Tim Shen Date: Fri, 21 Feb 2020 18:13:56 -0800 Subject: [PATCH] [MLIR] Add llvm.mlir.cast op for semantic preserving cast between dialect types. Summary: See discussion here: https://llvm.discourse.group/t/rfc-dialect-type-cast-op/538/11 Reviewers: ftynse Subscribers: bixia, sanjoy.google, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, Joonsoo, llvm-commits Differential Revision: https://reviews.llvm.org/D75141 --- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 19 +++++++++ .../StandardToLLVM/ConvertStandardToLLVM.cpp | 20 ++++++++++ mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 39 +++++++++++++++++++ .../StandardToLLVM/convert-to-llvmir.mlir | 36 +++++++++++++++++ .../Conversion/StandardToLLVM/invalid.mlir | 35 ++++++++++++++++- 5 files changed, 147 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index a4d74bf70cbe..601bfbf68926 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -686,6 +686,25 @@ def LLVM_ConstantOp let assemblyFormat = "`(` $value `)` attr-dict `:` type($res)"; } +def LLVM_DialectCastOp : LLVM_Op<"mlir.cast", [NoSideEffect]>, + Results<(outs AnyType:$res)>, + Arguments<(ins AnyType:$in)> { + let summary = "Type cast between LLVM dialect and Standard."; + let description = [{ + llvm.mlir.cast op casts between Standard and LLVM dialects. It only changes + the dialect, but does not change compile-time or runtime semantics. + + Notice that index type is not supported, as it's Standard-specific. + + Example: + llvm.mlir.cast %v : f16 to llvm.half + llvm.mlir.cast %v : llvm.float to f32 + llvm.mlir.cast %v : !llvm<"<2 x float>"> to vector<2xf32> + }]; + let assemblyFormat = "$in attr-dict `:` type($in) `to` type($res)"; + let verifier = "return ::verify(*this);"; +} + // Operations that correspond to LLVM intrinsics. With MLIR operation set being // extendable, there is no reason to introduce a hard boundary between "core" // operations and intrinsics. However, we systematically prefix them with diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 72985ffc639b..96d7e82d6a32 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -1807,6 +1807,24 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern { } }; +struct DialectCastOpLowering + : public LLVMLegalizationPattern { + using LLVMLegalizationPattern::LLVMLegalizationPattern; + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto castOp = cast(op); + OperandAdaptor transformed(operands); + if (transformed.in().getType() != + typeConverter.convertType(castOp.getType())) { + return matchFailure(); + } + rewriter.replaceOp(op, transformed.in()); + return matchSuccess(); + } +}; + // A `dim` is converted to a constant for static sizes and to an access to the // size stored in the memref descriptor for dynamic sizes. struct DimOpLowering : public LLVMLegalizationPattern { @@ -2772,6 +2790,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns( CopySignOpLowering, CosOpLowering, ConstLLVMOpLowering, + DialectCastOpLowering, DivFOpLowering, ExpOpLowering, LogOpLowering, @@ -2988,6 +3007,7 @@ struct LLVMLoweringPass : public ModulePass { mlir::LLVMConversionTarget::LLVMConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) { this->addLegalDialect(); + this->addIllegalOp(); } std::unique_ptr> diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index c1773bbd8120..a8f1dd56e02e 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -890,6 +890,45 @@ static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) { p.printRegion(initializer, /*printEntryBlockArgs=*/false); } +//===----------------------------------------------------------------------===// +// Verifier for LLVM::DialectCastOp. +//===----------------------------------------------------------------------===// + +static LogicalResult verify(DialectCastOp op) { + auto verifyMLIRCastType = [&op](Type type) -> LogicalResult { + if (auto llvmType = type.dyn_cast()) { + if (llvmType.isVectorTy()) + llvmType = llvmType.getVectorElementType(); + if (llvmType.isIntegerTy() || llvmType.isHalfTy() || + llvmType.isFloatTy() || llvmType.isDoubleTy()) { + return success(); + } + return op.emitOpError("type must be non-index integer types, float " + "types, or vector of mentioned types."); + } + if (auto vectorType = type.dyn_cast()) { + if (vectorType.getShape().size() > 1) + return op.emitOpError("only 1-d vector is allowed"); + type = vectorType.getElementType(); + } + if (type.isSignlessIntOrFloat()) + return success(); + // Note that memrefs are not supported. We currently don't have a use case + // for it, but even if we do, there are challenges: + // * if we allow memrefs to cast from/to memref descriptors, then the + // semantics of the cast op depends on the implementation detail of the + // descriptor. + // * if we allow memrefs to cast from/to bare pointers, some users might + // alternatively want metadata that only present in the descriptor. + // + // TODO(timshen): re-evaluate the memref cast design when it's needed. + return op.emitOpError("type must be non-index integer types, float types, " + "or vector of mentioned types."); + }; + return failure(failed(verifyMLIRCastType(op.in().getType())) || + failed(verifyMLIRCastType(op.getType()))); +} + // Parses one of the keywords provided in the list `keywords` and returns the // position of the parsed keyword in the list. If none of the keywords from the // list is parsed, returns -1. diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir index 27c249372b15..68aeef8a2e1f 100644 --- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir @@ -910,3 +910,39 @@ func @assume_alignment(%0 : memref<4x4xf16>) { assume_alignment %0, 16 : memref<4x4xf16> return } + +// ----- + +// CHECK-LABEL: func @mlir_cast_to_llvm +// CHECK-SAME: %[[ARG:.*]]: +func @mlir_cast_to_llvm(%0 : vector<2xf16>) -> !llvm<"<2 x half>"> { + %1 = llvm.mlir.cast %0 : vector<2xf16> to !llvm<"<2 x half>"> + // CHECK-NEXT: llvm.return %[[ARG]] + return %1 : !llvm<"<2 x half>"> +} + +// CHECK-LABEL: func @mlir_cast_from_llvm +// CHECK-SAME: %[[ARG:.*]]: +func @mlir_cast_from_llvm(%0 : !llvm<"<2 x half>">) -> vector<2xf16> { + %1 = llvm.mlir.cast %0 : !llvm<"<2 x half>"> to vector<2xf16> + // CHECK-NEXT: llvm.return %[[ARG]] + return %1 : vector<2xf16> +} + +// ----- + +// CHECK-LABEL: func @mlir_cast_to_llvm +// CHECK-SAME: %[[ARG:.*]]: +func @mlir_cast_to_llvm(%0 : f16) -> !llvm.half { + %1 = llvm.mlir.cast %0 : f16 to !llvm.half + // CHECK-NEXT: llvm.return %[[ARG]] + return %1 : !llvm.half +} + +// CHECK-LABEL: func @mlir_cast_from_llvm +// CHECK-SAME: %[[ARG:.*]]: +func @mlir_cast_from_llvm(%0 : !llvm.half) -> f16 { + %1 = llvm.mlir.cast %0 : !llvm.half to f16 + // CHECK-NEXT: llvm.return %[[ARG]] + return %1 : f16 +} diff --git a/mlir/test/Conversion/StandardToLLVM/invalid.mlir b/mlir/test/Conversion/StandardToLLVM/invalid.mlir index e0b1889e95c8..bb9c2728dcb8 100644 --- a/mlir/test/Conversion/StandardToLLVM/invalid.mlir +++ b/mlir/test/Conversion/StandardToLLVM/invalid.mlir @@ -1,13 +1,44 @@ -// RUN: mlir-opt %s -verify-diagnostics -split-input-file +// RUN: mlir-opt %s -convert-std-to-llvm -verify-diagnostics -split-input-file #map1 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)> func @invalid_memref_cast(%arg0: memref) { %c1 = constant 1 : index %c0 = constant 0 : index - // expected-error@+1: 'std.memref_cast' op operand #0 must be unranked.memref of any type values or memref of any type values, + // expected-error@+1 {{'std.memref_cast' op operand #0 must be unranked.memref of any type values or memref of any type values, but got '!llvm<"{ double*, double*, i64, [2 x i64], [2 x i64] }">'}} %5 = memref_cast %arg0 : memref to memref %25 = std.subview %5[%c0, %c0][%c1, %c1][] : memref to memref return } +// ----- + +func @mlir_cast_to_llvm(%0 : index) -> !llvm.i64 { + // expected-error@+1 {{'llvm.mlir.cast' op type must be non-index integer types, float types, or vector of mentioned types}} + %1 = llvm.mlir.cast %0 : index to !llvm.i64 + return %1 : !llvm.i64 +} + +// ----- + +func @mlir_cast_from_llvm(%0 : !llvm.i64) -> index { + // expected-error@+1 {{'llvm.mlir.cast' op type must be non-index integer types, float types, or vector of mentioned types}} + %1 = llvm.mlir.cast %0 : !llvm.i64 to index + return %1 : index +} + +// ----- + +func @mlir_cast_to_llvm_int(%0 : i32) -> !llvm.i64 { + // expected-error@+1 {{failed to legalize operation 'llvm.mlir.cast' that was explicitly marked illegal}} + %1 = llvm.mlir.cast %0 : i32 to !llvm.i64 + return %1 : !llvm.i64 +} + +// ----- + +func @mlir_cast_to_llvm_vec(%0 : vector<1x1xf32>) -> !llvm<"<1 x float>"> { + // expected-error@+1 {{'llvm.mlir.cast' op only 1-d vector is allowed}} + %1 = llvm.mlir.cast %0 : vector<1x1xf32> to !llvm<"<1 x float>"> + return %1 : !llvm<"<1 x float>"> +}