[MLIR][LLVM] Add simple folders for bitcast/addrspacecast/gep

Add 5 simple folders
* bitcast(x : T0, T0) -> x
* addrcast(x : T0, T0) -> x
* bitcast(bitcast(x : T0, T1), T0) -> x
* addrcast(addrcast(x : T0, T1), T0) -> x
* gep %x:T, 0 -> %x:T

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D116715
This commit is contained in:
William S. Moses 2022-01-05 20:34:01 -05:00
parent 75ea6b4319
commit 358d020017
3 changed files with 99 additions and 2 deletions

View File

@ -327,6 +327,7 @@ def LLVM_GEPOp
let assemblyFormat = [{ let assemblyFormat = [{
$base `[` $indices `]` attr-dict `:` functional-type(operands, results) $base `[` $indices `]` attr-dict `:` functional-type(operands, results)
}]; }];
let hasFolder = 1;
} }
def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes { def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes {
@ -398,10 +399,14 @@ class LLVM_CastOp<string mnemonic, string builderFunc, Type type,
let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }]; let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }];
} }
def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "CreateBitCast", def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "CreateBitCast",
LLVM_AnyNonAggregate, LLVM_AnyNonAggregate>; LLVM_AnyNonAggregate, LLVM_AnyNonAggregate> {
let hasFolder = 1;
}
def LLVM_AddrSpaceCastOp : LLVM_CastOp<"addrspacecast", "CreateAddrSpaceCast", def LLVM_AddrSpaceCastOp : LLVM_CastOp<"addrspacecast", "CreateAddrSpaceCast",
LLVM_ScalarOrVectorOf<LLVM_AnyPointer>, LLVM_ScalarOrVectorOf<LLVM_AnyPointer>,
LLVM_ScalarOrVectorOf<LLVM_AnyPointer>>; LLVM_ScalarOrVectorOf<LLVM_AnyPointer>> {
let hasFolder = 1;
}
def LLVM_IntToPtrOp : LLVM_CastOp<"inttoptr", "CreateIntToPtr", def LLVM_IntToPtrOp : LLVM_CastOp<"inttoptr", "CreateIntToPtr",
LLVM_ScalarOrVectorOf<AnyInteger>, LLVM_ScalarOrVectorOf<AnyInteger>,
LLVM_ScalarOrVectorOf<LLVM_AnyPointer>>; LLVM_ScalarOrVectorOf<LLVM_AnyPointer>>;

View File

@ -19,6 +19,7 @@
#include "mlir/IR/DialectImplementation.h" #include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/FunctionImplementation.h" #include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/TypeSwitch.h" #include "llvm/ADT/TypeSwitch.h"
@ -2259,6 +2260,48 @@ static LogicalResult verify(FenceOp &op) {
return success(); return success();
} }
//===----------------------------------------------------------------------===//
// Folder for LLVM::BitcastOp
//===----------------------------------------------------------------------===//
OpFoldResult LLVM::BitcastOp::fold(ArrayRef<Attribute> operands) {
// bitcast(x : T0, T0) -> x
if (getArg().getType() == getType())
return getArg();
// bitcast(bitcast(x : T0, T1), T0) -> x
if (auto prev = getArg().getDefiningOp<BitcastOp>())
if (prev.getArg().getType() == getType())
return prev.getArg();
return {};
}
//===----------------------------------------------------------------------===//
// Folder for LLVM::AddrSpaceCastOp
//===----------------------------------------------------------------------===//
OpFoldResult LLVM::AddrSpaceCastOp::fold(ArrayRef<Attribute> operands) {
// addrcast(x : T0, T0) -> x
if (getArg().getType() == getType())
return getArg();
// addrcast(addrcast(x : T0, T1), T0) -> x
if (auto prev = getArg().getDefiningOp<AddrSpaceCastOp>())
if (prev.getArg().getType() == getType())
return prev.getArg();
return {};
}
//===----------------------------------------------------------------------===//
// Folder for LLVM::GEPOp
//===----------------------------------------------------------------------===//
OpFoldResult LLVM::GEPOp::fold(ArrayRef<Attribute> operands) {
// gep %x:T, 0 -> %x
if (getBase().getType() == getType() && getIndices().size() == 1 &&
matchPattern(getIndices()[0], m_Zero()))
return getBase();
return {};
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// LLVMDialect initialization, type parsing, and registration. // LLVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -38,3 +38,52 @@ llvm.func @no_fold_extractvalue(%arr: !llvm.array<4xf32>) -> f32 {
llvm.return %3 : f32 llvm.return %3 : f32
} }
// -----
// CHECK-LABEL: fold_bitcast
// CHECK-SAME: %[[a0:arg[0-9]+]]
// CHECK-NEXT: llvm.return %[[a0]]
llvm.func @fold_bitcast(%x : !llvm.ptr<i8>) -> !llvm.ptr<i8> {
%c = llvm.bitcast %x : !llvm.ptr<i8> to !llvm.ptr<i8>
llvm.return %c : !llvm.ptr<i8>
}
// CHECK-LABEL: fold_bitcast2
// CHECK-SAME: %[[a0:arg[0-9]+]]
// CHECK-NEXT: llvm.return %[[a0]]
llvm.func @fold_bitcast2(%x : !llvm.ptr<i8>) -> !llvm.ptr<i8> {
%c = llvm.bitcast %x : !llvm.ptr<i8> to !llvm.ptr<i32>
%d = llvm.bitcast %c : !llvm.ptr<i32> to !llvm.ptr<i8>
llvm.return %d : !llvm.ptr<i8>
}
// -----
// CHECK-LABEL: fold_addrcast
// CHECK-SAME: %[[a0:arg[0-9]+]]
// CHECK-NEXT: llvm.return %[[a0]]
llvm.func @fold_addrcast(%x : !llvm.ptr<i8>) -> !llvm.ptr<i8> {
%c = llvm.addrspacecast %x : !llvm.ptr<i8> to !llvm.ptr<i8>
llvm.return %c : !llvm.ptr<i8>
}
// CHECK-LABEL: fold_addrcast2
// CHECK-SAME: %[[a0:arg[0-9]+]]
// CHECK-NEXT: llvm.return %[[a0]]
llvm.func @fold_addrcast2(%x : !llvm.ptr<i8>) -> !llvm.ptr<i8> {
%c = llvm.addrspacecast %x : !llvm.ptr<i8> to !llvm.ptr<i32, 5>
%d = llvm.addrspacecast %c : !llvm.ptr<i32, 5> to !llvm.ptr<i8>
llvm.return %d : !llvm.ptr<i8>
}
// -----
// CHECK-LABEL: fold_gep
// CHECK-SAME: %[[a0:arg[0-9]+]]
// CHECK-NEXT: llvm.return %[[a0]]
llvm.func @fold_gep(%x : !llvm.ptr<i8>) -> !llvm.ptr<i8> {
%c0 = arith.constant 0 : i32
%c = llvm.getelementptr %x[%c0] : (!llvm.ptr<i8>, i32) -> !llvm.ptr<i8>
llvm.return %c : !llvm.ptr<i8>
}