[mlir][Linalg] Introduce folding patterns to remove certain MemRefCastOp
Summary: Canonicalization and folding patterns in StandardOps may interfere
with the needs of Linalg. This revision introduces specific foldings for
dynamic memrefs that can be proven to be static.

Very concretely, it determines whether a memref_cast can be folded away into
the parent Linalg op, rewriting:

```mlir
%1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
%2 = linalg.slice %1 ... : memref<?x?xf32> ...

// or

%1 = memref_cast %0 : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
       to memref<?x?xf32>
linalg.generic(%1 ...) : memref<?x?xf32> ...
```

into

```mlir
%2 = linalg.slice %0 ... : memref<8x16xf32> ...

// or

linalg.generic(%0 ... : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
```

Reviewers: ftynse, aartbik, jsetoain, tetuante, asaadaldien

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst,
arpith-jacob, mgester, lucyrfox, liufengdb, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D73565
This commit is contained in:
parent 02adfb5155
commit ea1e3369f7
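The folds introduced here are exercised by the canonicalizer (see the
`-canonicalize` test added at the end of this diff). As a reading aid, here is
a minimal, hedged C++ sketch of driving them programmatically; the pass setup
below is an illustration based on the MLIR pass API of this period and is not
part of the commit:

```cpp
// Sketch only: running the canonicalizer invokes each op's `fold` hook,
// which in turn calls the foldMemRefCast helper added below.
// API names reflect MLIR circa early 2020 and are assumptions.
#include "mlir/IR/Module.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Passes.h"

static mlir::LogicalResult canonicalizeModule(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());    // module-level pass manager
  pm.addPass(mlir::createCanonicalizerPass());  // folds memref_cast into Linalg ops
  return pm.run(module);
}
```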
@@ -117,6 +117,8 @@ def Linalg_ReshapeOp : Linalg_Op<"reshape", [NoSideEffect]>,
    static StringRef getReassociationAttrName() { return "reassociation"; }
    MemRefType getViewType() { return view().getType().cast<MemRefType>(); }
  }];

  let hasFolder = 1;
}

def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>,

@@ -188,6 +190,8 @@ def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>,
      return res;
    }
  }];

  let hasFolder = 1;
}

def Linalg_TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>,

@@ -222,6 +226,8 @@ def Linalg_TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>,
    static StringRef getPermutationAttrName() { return "permutation"; }
    ShapedType getShapedType() { return view().getType().cast<ShapedType>(); }
  }];

  let hasFolder = 1;
}

def Linalg_YieldOp : Linalg_Op<"yield", [NativeOpTrait<"IsTerminator">]>,
@@ -270,6 +270,8 @@ def CopyOp : LinalgStructured_Op<"copy", [NInputs<1>, NOutputs<1>]> {
    }
  }];
  let verifier = [{ return ::verify(*this); }];

  let hasFolder = 1;
}

def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> {

@@ -287,6 +289,8 @@ def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> {
    }
  }];
  let verifier = [{ return ::verify(*this); }];

  let hasFolder = 1;
}

def DotOp : LinalgStructured_Op<"dot", [NInputs<2>, NOutputs<1>]> {

@@ -302,6 +306,8 @@ def DotOp : LinalgStructured_Op<"dot", [NInputs<2>, NOutputs<1>]> {
                          StringAttr::get(getReductionIteratorTypeName(), ctx), ctx);
    }
  }];

  let hasFolder = 1;
}

def MatvecOp : LinalgStructured_Op<"matvec", [NInputs<2>, NOutputs<1>]> {

@@ -319,6 +325,8 @@ def MatvecOp : LinalgStructured_Op<"matvec", [NInputs<2>, NOutputs<1>]> {
      return ArrayAttr::get(iters, ctx);
    }
  }];

  let hasFolder = 1;
}

def MatmulOp : LinalgStructured_Op<"matmul", [NInputs<2>, NOutputs<1>]> {

@@ -337,6 +345,8 @@ def MatmulOp : LinalgStructured_Op<"matmul", [NInputs<2>, NOutputs<1>]> {
      return ArrayAttr::get(iters, ctx);
    }
  }];

  let hasFolder = 1;
}

def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {

@@ -406,7 +416,10 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
          .cast<IntegerAttr>().getValue().getSExtValue();
    }
  }];

  let verifier = [{ return ::verify(*this); }];

  let hasFolder = 1;
}

def LinalgOperand: Type<

@@ -583,7 +596,10 @@ def GenericOp : GenericOpBase<"generic"> {
    tensor SSA values are expected to be useful and will be added in the near
    future.
  }];

  let verifier = [{ return ::verify(*this); }];

  let hasFolder = 1;
}

def IndexedGenericOp : GenericOpBase<"indexed_generic"> {

@@ -710,7 +726,10 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
    tensor SSA values are expected to be useful and will be added in the near
    future.
  }];

  let verifier = [{ return ::verify(*this); }];

  let hasFolder = 1;
}

#endif // LINALG_STRUCTURED_OPS
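Each `let hasFolder = 1;` above makes ODS declare a `fold` hook on the
corresponding op. As a reading aid, here is a hedged sketch of what such a
hook looks like for a hypothetical op (`MyLinalgOp` is an illustrative
placeholder; the actual definitions for ConvOp, CopyOp, and the other ops
appear in the C++ hunks below):

```cpp
// Illustrative only: `MyLinalgOp` is not a real op. The hooks added by this
// commit all forward to the shared foldMemRefCast helper defined below.
LogicalResult MyLinalgOp::fold(ArrayRef<Attribute> /*operands*/,
                               SmallVectorImpl<OpFoldResult> &/*results*/) {
  return foldMemRefCast(*this);
}
```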
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
@@ -31,6 +32,89 @@
using namespace mlir;
using namespace mlir::linalg;

/// Determines whether it is possible to fold it away in the parent Linalg op:
///
/// ```mlir
///   %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
///   %2 = linalg.slice %1 ... : memref<?x?xf32> ...
///   // or
///   %1 = memref_cast %0 : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
///          to memref<?x?xf32>
///   linalg.generic(%1 ...) : memref<?x?xf32> ...
/// ```
///
/// into
///
/// ```mlir
///   %2 = linalg.slice %0 ... : memref<8x16xf32> ...
///   // or
///   linalg.generic(%0 ... : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
/// ```
///
static bool canFold(MemRefCastOp castOp) {
  MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
  MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();

  // If we don't have MemRefType as source and destination, bail out.
  if (!sourceType || !resultType)
    return false;

  // If resultType has a map, it needs to be the same as the source type to
  // canonicalize.
  if (!resultType.getAffineMaps().empty() &&
      sourceType.getAffineMaps() != resultType.getAffineMaps())
    return false;

  // Ensure that:
  //   1. source is static
  //   2. source and target have the same rank (will be extended when needed)
  //   3. if result is partially static, ensure sizes match.
  if (!sourceType.hasStaticShape() ||
      sourceType.getRank() != resultType.getRank())
    return false;

  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    auto sourceSize = std::get<0>(it);
    auto resultSize = std::get<1>(it);
    if (ShapedType::isDynamic(resultSize))
      continue;
    if (sourceSize != resultSize)
      return false;
  }

  // If source has a map, it can only canonicalize if it is the canonical
  // strided layout map.
  if (sourceType.getAffineMaps().empty())
    return true;

  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(sourceType, strides, offset);
  (void)res;
  assert(succeeded(res));
  auto stridedMap =
      makeStridedLinearLayoutMap(strides, offset, castOp.getContext());
  AffineMap sourceMap = sourceType.getAffineMaps().front();
  return sourceMap == stridedMap;
}

/// This is a common utility used for patterns of the form
/// ```
///    someop(memrefcast) -> someop
/// ```
/// It folds the source of any memref_cast into the root operation directly.
static LogicalResult foldMemRefCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
    if (castOp && canFold(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

///////////////////// Operations defined with Tablegen /////////////////////////
// For such operations that do not correspond to library calls (i.e. defined in
// LinalgOps.td), we define an overloaded `print` function and a
@@ -1077,3 +1161,54 @@ ArrayAttr mlir::linalg::MatmulOp::indexing_maps() {
ArrayAttr mlir::linalg::MatvecOp::indexing_maps() {
  return getIndexingMaps(getOperation());
}

// TODO(ntv, rriddle): Consider making all this boilerplate easy to autogenerate
// with Tablegen. This seems a desirable property in the context of OpInterfaces
// where a Linalg "named" op **isa** LinalgOp.
LogicalResult ConvOp::fold(ArrayRef<Attribute>,
                           SmallVectorImpl<OpFoldResult> &) {
  return foldMemRefCast(*this);
}
LogicalResult CopyOp::fold(ArrayRef<Attribute>,
                           SmallVectorImpl<OpFoldResult> &) {
  return foldMemRefCast(*this);
}
LogicalResult DotOp::fold(ArrayRef<Attribute>,
                          SmallVectorImpl<OpFoldResult> &) {
  return foldMemRefCast(*this);
}
LogicalResult FillOp::fold(ArrayRef<Attribute>,
                           SmallVectorImpl<OpFoldResult> &) {
  return foldMemRefCast(*this);
}
LogicalResult GenericOp::fold(ArrayRef<Attribute>,
                              SmallVectorImpl<OpFoldResult> &) {
  return foldMemRefCast(*this);
}
LogicalResult IndexedGenericOp::fold(ArrayRef<Attribute>,
                                     SmallVectorImpl<OpFoldResult> &) {
  return foldMemRefCast(*this);
}
LogicalResult MatvecOp::fold(ArrayRef<Attribute>,
                             SmallVectorImpl<OpFoldResult> &) {
  return foldMemRefCast(*this);
}
LogicalResult MatmulOp::fold(ArrayRef<Attribute>,
                             SmallVectorImpl<OpFoldResult> &) {
  return foldMemRefCast(*this);
}
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute>) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return {};
}
OpFoldResult SliceOp::fold(ArrayRef<Attribute>) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return {};
}
OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return {};
}
@@ -0,0 +1,20 @@
// RUN: mlir-opt %s -canonicalize | FileCheck %s

// CHECK-LABEL: func @memref_cast(
func @memref_cast(%a: index, %b: index) -> memref<?x?xf32> {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c8 = constant 8 : index
  %c16 = constant 16 : index
  %1 = alloc (%b) : memref<?xi8>
  %2 = view %1[][] : memref<?xi8> to memref<16x16xf32>
  %3 = memref_cast %2 : memref<16x16xf32> to memref<?x?xf32>
  %r0 = linalg.range %c0:%c8:%c1 : !linalg.range

  // CHECK:  linalg.slice {{.*}} : memref<16x16xf32>, !linalg.range, !linalg.range, memref<?x?xf32>
  %4 = linalg.slice %3[%r0, %r0] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32>

  // CHECK:  linalg.matmul{{.*}}: memref<16x16xf32>, memref<16x16xf32>, memref<16x16xf32>
  linalg.matmul(%3, %3, %3) : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
  return %4: memref<?x?xf32>
}