[MLIR] MemRef Normalization for Dialects

When dealing with dialects that will results in function calls to
external libraries, it is important to be able to handle maps as some
dialects may require mapped data.  Before this patch, the detection of
whether normalization can apply or not, operations are compared to an
explicit list of operations (`alloc`, `dealloc`, `return`) or to the
presence of specific operation interfaces (`AffineReadOpInterface`,
`AffineWriteOpInterface`, `AffineDMAStartOp`, or `AffineDMAWaitOp`).

This patch add a trait, `MemRefsNormalizable` to determine if an
operation can have its `memrefs` normalized.

This trait can be used in turn by dialects to assert that such
operations are compatible with normalization of `memrefs` with
nontrivial memory layout specification. An example is given in the
literal tests.

Differential Revision: https://reviews.llvm.org/D86236
This commit is contained in:
Alexandre E. Eichenberger 2020-08-27 10:47:33 +05:30 committed by Uday Bondhugula
parent b5924a8e27
commit a14a2805b0
10 changed files with 114 additions and 21 deletions

View File

@ -247,6 +247,18 @@ foo.region_op {
This trait is an important structural property of the IR, and enables operations
to have [passes](PassManagement.md) scheduled under them.
### MemRefsNormalizable
* `OpTrait::MemRefsNormalizable` -- `MemRefsNormalizable`
This trait is used to flag operations that can accommodate `MemRefs` with
non-identity memory-layout specifications. This trait indicates that the
normalization of memory layout can be performed for such operations.
`MemRefs` normalization consists of replacing an original memory reference
with layout specifications to an equivalent memory reference where
the specified memory layout is applied by rewritting accesses and types
associated with that memory reference.
### Single Block with Implicit Terminator
* `OpTrait::SingleBlockImplicitTerminator<typename TerminatorOpType>` :

View File

@ -80,8 +80,9 @@ bool isTopLevelValue(Value value);
// multiple stride levels (possibly using AffineMaps to specify multiple levels
// of striding).
// TODO: Consider replacing src/dst memref indices with view memrefs.
class AffineDmaStartOp : public Op<AffineDmaStartOp, OpTrait::VariadicOperands,
OpTrait::ZeroResult> {
class AffineDmaStartOp
: public Op<AffineDmaStartOp, OpTrait::MemRefsNormalizable,
OpTrait::VariadicOperands, OpTrait::ZeroResult> {
public:
using Op::Op;
@ -268,8 +269,9 @@ public:
// ...
// affine.dma_wait %tag[%index], %num_elements : memref<1xi32, 2>
//
class AffineDmaWaitOp : public Op<AffineDmaWaitOp, OpTrait::VariadicOperands,
OpTrait::ZeroResult> {
class AffineDmaWaitOp
: public Op<AffineDmaWaitOp, OpTrait::MemRefsNormalizable,
OpTrait::VariadicOperands, OpTrait::ZeroResult> {
public:
using Op::Op;

View File

@ -405,7 +405,8 @@ def AffineIfOp : Affine_Op<"if",
class AffineLoadOpBase<string mnemonic, list<OpTrait> traits = []> :
Affine_Op<mnemonic, !listconcat(traits,
[DeclareOpInterfaceMethods<AffineReadOpInterface>])> {
[DeclareOpInterfaceMethods<AffineReadOpInterface>,
MemRefsNormalizable])> {
let arguments = (ins Arg<AnyMemRef, "the reference to load from",
[MemRead]>:$memref,
Variadic<Index>:$indices);
@ -732,7 +733,8 @@ def AffinePrefetchOp : Affine_Op<"prefetch"> {
class AffineStoreOpBase<string mnemonic, list<OpTrait> traits = []> :
Affine_Op<mnemonic, !listconcat(traits,
[DeclareOpInterfaceMethods<AffineWriteOpInterface>])> {
[DeclareOpInterfaceMethods<AffineWriteOpInterface>,
MemRefsNormalizable])> {
code extraClassDeclarationBase = [{
/// Returns the operand index of the value to be stored.
unsigned getStoredValOperandIndex() { return 0; }

View File

@ -658,7 +658,7 @@ def BranchOp : Std_Op<"br",
// CallOp
//===----------------------------------------------------------------------===//
def CallOp : Std_Op<"call", [CallOpInterface]> {
def CallOp : Std_Op<"call", [CallOpInterface, MemRefsNormalizable]> {
let summary = "call operation";
let description = [{
The `call` operation represents a direct call to a function that is within
@ -1388,7 +1388,8 @@ def SinOp : FloatUnaryOp<"sin"> {
// DeallocOp
//===----------------------------------------------------------------------===//
def DeallocOp : Std_Op<"dealloc", [MemoryEffects<[MemFree]>]> {
def DeallocOp : Std_Op<"dealloc",
[MemoryEffects<[MemFree]>, MemRefsNormalizable]> {
let summary = "memory deallocation operation";
let description = [{
The `dealloc` operation frees the region of memory referenced by a memref
@ -2144,8 +2145,8 @@ def RemFOp : FloatArithmeticOp<"remf"> {
// ReturnOp
//===----------------------------------------------------------------------===//
def ReturnOp : Std_Op<"return", [NoSideEffect, HasParent<"FuncOp">, ReturnLike,
Terminator]> {
def ReturnOp : Std_Op<"return", [NoSideEffect, HasParent<"FuncOp">,
MemRefsNormalizable, ReturnLike, Terminator]> {
let summary = "return operation";
let description = [{
The `return` operation represents a return operation within a function.

View File

@ -1698,6 +1698,9 @@ def SameOperandsAndResultElementType :
NativeOpTrait<"SameOperandsAndResultElementType">;
// Op is a terminator.
def Terminator : NativeOpTrait<"IsTerminator">;
// Op can be safely normalized in the presence of MemRefs with
// non-identity maps.
def MemRefsNormalizable : NativeOpTrait<"MemRefsNormalizable">;
// Op's regions have a single block with the specified terminator.
class SingleBlockImplicitTerminator<string op>

View File

@ -1212,6 +1212,20 @@ struct NoRegionArguments : public TraitBase<ConcrentType, NoRegionArguments> {
}
};
/// This trait is used to flag operations that can accommodate MemRefs with
/// non-identity memory-layout specifications. This trait indicates that the
/// normalization of memory layout can be performed for such operations.
/// MemRefs normalization consists of replacing an original memory reference
/// with layout specifications to an equivalent memory reference where the
/// specified memory layout is applied by rewritting accesses and types
/// associated with that memory reference.
// TODO: Right now, the operands of an operation are either all normalizable,
// or not. In the future, we may want to allow some of the operands to be
// normalizable.
template <typename ConcrentType>
struct MemRefsNormalizable
: public TraitBase<ConcrentType, MemRefsNormalizable> {};
} // end namespace OpTrait
//===----------------------------------------------------------------------===//

View File

@ -106,23 +106,15 @@ void NormalizeMemRefs::runOnOperation() {
normalizeFuncOpMemRefs(funcOp, moduleOp);
}
/// Return true if this operation dereferences one or more memref's.
/// TODO: Temporary utility, will be replaced when this is modeled through
/// side-effects/op traits.
static bool isMemRefDereferencingOp(Operation &op) {
return isa<AffineReadOpInterface, AffineWriteOpInterface, AffineDmaStartOp,
AffineDmaWaitOp>(op);
}
/// Check whether all the uses of oldMemRef are either dereferencing uses or the
/// op is of type : DeallocOp, CallOp or ReturnOp. Only if these constraints
/// are satisfied will the value become a candidate for replacement.
/// TODO: Extend this for DimOps.
static bool isMemRefNormalizable(Value::user_range opUsers) {
if (llvm::any_of(opUsers, [](Operation *op) {
if (isMemRefDereferencingOp(*op))
if (op->hasTrait<OpTrait::MemRefsNormalizable>())
return false;
return !isa<DeallocOp, CallOp, ReturnOp>(*op);
return true;
}))
return false;
return true;

View File

@ -279,7 +279,7 @@ LogicalResult mlir::replaceAllMemRefUsesWith(
// Currently we support the following non-dereferencing ops to be a
// candidate for replacement: Dealloc, CallOp and ReturnOp.
// TODO: Add support for other kinds of ops.
if (!isa<DeallocOp, CallOp, ReturnOp>(*op))
if (!op->hasTrait<OpTrait::MemRefsNormalizable>())
return failure();
}

View File

@ -0,0 +1,57 @@
// RUN: mlir-opt -normalize-memrefs %s | FileCheck %s
// For all these cases, we test if MemRefs Normalization works with the test
// operations.
// * test.op_norm: this operation has the MemRefsNormalizable attribute. The tests
// that include this operation are constructed so that the normalization should
// happen.
// * test_op_nonnorm: this operation does not have the MemRefsNormalization
// attribute. The tests that include this operation are contructed so that the
// normalization should not happen.
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 floordiv 32, d3 floordiv 64, d2 mod 32, d3 mod 64)>
// Test with op_norm and maps in arguments and in the operations in the function.
// CHECK-LABEL: test_norm
// CHECK-SAME: (%[[ARG0:[a-z0-9]*]]: memref<1x16x1x1x32x64xf32>)
func @test_norm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () {
%0 = alloc() : memref<1x16x14x14xf32, #map0>
"test.op_norm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> ()
dealloc %0 : memref<1x16x14x14xf32, #map0>
// CHECK: %[[v0:[a-z0-9]*]] = alloc() : memref<1x16x1x1x32x64xf32>
// CHECK: "test.op_norm"(%[[ARG0]], %[[v0]]) : (memref<1x16x1x1x32x64xf32>, memref<1x16x1x1x32x64xf32>) -> ()
// CHECK: dealloc %[[v0]] : memref<1x16x1x1x32x64xf32>
return
}
// Same test with op_nonnorm, with maps in the argmentets and the operations in the function.
// CHECK-LABEL: test_nonnorm
// CHECK-SAME: (%[[ARG0:[a-z0-9]*]]: memref<1x16x14x14xf32, #map0>)
func @test_nonnorm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () {
%0 = alloc() : memref<1x16x14x14xf32, #map0>
"test.op_nonnorm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> ()
dealloc %0 : memref<1x16x14x14xf32, #map0>
// CHECK: %[[v0:[a-z0-9]*]] = alloc() : memref<1x16x14x14xf32, #map0>
// CHECK: "test.op_nonnorm"(%[[ARG0]], %[[v0]]) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> ()
// CHECK: dealloc %[[v0]] : memref<1x16x14x14xf32, #map0>
return
}
// Test with op_norm, with maps in the operations in the function.
// CHECK-LABEL: test_norm_mix
// CHECK-SAME: (%[[ARG0:[a-z0-9]*]]: memref<1x16x1x1x32x64xf32>
func @test_norm_mix(%arg0 : memref<1x16x1x1x32x64xf32>) -> () {
%0 = alloc() : memref<1x16x14x14xf32, #map0>
"test.op_norm"(%arg0, %0) : (memref<1x16x1x1x32x64xf32>, memref<1x16x14x14xf32, #map0>) -> ()
dealloc %0 : memref<1x16x14x14xf32, #map0>
// CHECK: %[[v0:[a-z0-9]*]] = alloc() : memref<1x16x1x1x32x64xf32>
// CHECK: "test.op_norm"(%[[ARG0]], %[[v0]]) : (memref<1x16x1x1x32x64xf32>, memref<1x16x1x1x32x64xf32>) -> ()
// CHECK: dealloc %[[v0]] : memref<1x16x1x1x32x64xf32>
return
}

View File

@ -618,6 +618,16 @@ def OpM : TEST_Op<"op_m"> {
let arguments = (ins I32, OptionalAttr<I32Attr>:$optional_attr);
let results = (outs I32);
}
// Test for memrefs normalization of an op with normalizable memrefs.
def OpNorm : TEST_Op<"op_norm", [MemRefsNormalizable]> {
let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y);
}
// Test for memrefs normalization of an op without normalizable memrefs.
def OpNonNorm : TEST_Op<"op_nonnorm"> {
let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y);
}
// Pattern add the argument plus a increasing static number hidden in
// OpMTest function. That value is set into the optional argument.
// That way, we will know if operations is called once or twice.