forked from OSchip/llvm-project
[mlir][transform] Add multi-buffering to the transform dialect
Add the plumbing necessary to call the memref dialect's multiBuffer function from the transform dialect. This allows separating the choice of which buffers to multi-buffer from the actual transformation. Alter the multiBuffer function to return the newly created allocation if multi-buffering succeeds; this is necessary to tell the transform dialect hooks which allocation multi-buffering created. Reviewed By: ftynse, nicolasvasilache Differential Revision: https://reviews.llvm.org/D133985
This commit is contained in:
parent
1e818cd8e2
commit
3f050f6ac4
|
@ -1,2 +1,3 @@
|
|||
add_subdirectory(IR)
|
||||
add_subdirectory(TransformOps)
|
||||
add_subdirectory(Transforms)
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
set(LLVM_TARGET_DEFINITIONS MemRefTransformOps.td)
|
||||
mlir_tablegen(MemRefTransformOps.h.inc -gen-op-decls)
|
||||
mlir_tablegen(MemRefTransformOps.cpp.inc -gen-op-defs)
|
||||
add_public_tablegen_target(MLIRMemRefTransformOpsIncGen)
|
||||
|
||||
add_mlir_doc(MemRefTransformOps MemRefTransformOps Dialects/ -gen-op-doc)
|
|
@ -0,0 +1,33 @@
|
|||
//===- MemRefTransformOps.h - MemRef transformation ops ---------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_MEMREF_TRANSFORMOPS_MEMREFTRANSFORMOPS_H
|
||||
#define MLIR_DIALECT_MEMREF_TRANSFORMOPS_MEMREFTRANSFORMOPS_H
|
||||
|
||||
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace memref {
|
||||
class AllocOp;
|
||||
} // namespace memref
|
||||
} // namespace mlir
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h.inc"
|
||||
|
||||
namespace mlir {
class DialectRegistry;

namespace memref {
/// Registers the MemRef transform dialect extension with `registry`.
///
/// \param registry the dialect registry the extension is added to.
void registerTransformDialectExtension(DialectRegistry &registry);
} // namespace memref
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_MEMREF_TRANSFORMOPS_MEMREFTRANSFORMOPS_H
|
|
@ -0,0 +1,52 @@
|
|||
//===- MemRefTransformOps.td - MemRef transformation ops --*- tablegen -*--===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MEMREF_TRANSFORM_OPS
|
||||
#define MEMREF_TRANSFORM_OPS
|
||||
|
||||
include "mlir/Dialect/Transform/IR/TransformDialect.td"
|
||||
include "mlir/Dialect/Transform/IR/TransformEffects.td"
|
||||
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
|
||||
include "mlir/Dialect/PDL/IR/PDLTypes.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
// Transform dialect op `transform.memref.multibuffer`.
//
// Takes a handle (`$target`) and a strictly positive 64-bit integer attribute
// (`$factor`), and produces a handle (`$transformed`). The C++ hook declared in
// `extraClassDeclaration` receives each payload op as a `memref::AllocOp`.
// NOTE(review): the per-payload dispatch is presumably provided by
// `TransformEachOpTrait` -- confirm against the trait's documentation.
def MemRefMultiBufferOp : Op<Transform_Dialect, "memref.multibuffer",
|
||||
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
|
||||
TransformOpInterface, TransformEachOpTrait]> {
|
||||
let summary = "Multibuffers an allocation";
|
||||
let description = [{
|
||||
Transformation to do multi-buffering/array expansion to remove
|
||||
dependencies on the temporary allocation between consecutive loop
|
||||
iterations. This transform expands the size of an allocation by
|
||||
a given multiplicative factor and fixes up any users of the
|
||||
multibuffered allocation.
|
||||
|
||||
#### Return modes
|
||||
|
||||
This operation returns the new allocation if multi-buffering
|
||||
succeeds, and failure otherwise.
|
||||
}];
|
||||
|
||||
// `target` is the handle to transform; `factor` is constrained to be a
// positive I64 attribute.
let arguments =
|
||||
(ins PDL_Operation:$target,
|
||||
ConfinedAttr<I64Attr, [IntPositive]>:$factor);
|
||||
|
||||
let results = (outs PDL_Operation:$transformed);
|
||||
|
||||
// Only `$target` has dedicated syntax; `factor` is supplied through the
// attribute dictionary, e.g. `{factor = 2 : i64}`.
let assemblyFormat = "$target attr-dict";
|
||||
|
||||
// Declares the per-payload hook that is implemented in C++.
let extraClassDeclaration = [{
|
||||
::mlir::DiagnosedSilenceableFailure applyToOne(
|
||||
memref::AllocOp target,
|
||||
::llvm::SmallVector<::mlir::Operation *> &results,
|
||||
::mlir::transform::TransformState &state);
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // MEMREF_TRANSFORM_OPS
|
|
@ -62,8 +62,8 @@ void populateSimplifyExtractStridedMetadataOpPatterns(
|
|||
|
||||
/// Transformation to do multi-buffering/array expansion to remove dependencies
|
||||
/// on the temporary allocation between consecutive loop iterations.
|
||||
/// It return success if the allocation was multi-buffered and returns failure()
|
||||
/// otherwise.
|
||||
/// It returns the new allocation if the original allocation was multi-buffered
|
||||
/// and returns failure() otherwise.
|
||||
/// Example:
|
||||
/// ```
|
||||
/// %0 = memref.alloc() : memref<4x128xf32>
|
||||
|
@ -85,7 +85,8 @@ void populateSimplifyExtractStridedMetadataOpPatterns(
|
|||
/// "some_use"(%sv) : (memref<4x128xf32, strided<...>) -> ()
|
||||
/// }
|
||||
/// ```
|
||||
LogicalResult multiBuffer(memref::AllocOp allocOp, unsigned multiplier);
|
||||
FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
|
||||
unsigned multiplier);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Passes
|
||||
|
|
|
@ -41,6 +41,7 @@
|
|||
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
|
||||
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
|
||||
#include "mlir/Dialect/OpenACC/OpenACC.h"
|
||||
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
|
||||
|
@ -112,6 +113,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
|
|||
// Register all dialect extensions.
|
||||
bufferization::registerTransformDialectExtension(registry);
|
||||
linalg::registerTransformDialectExtension(registry);
|
||||
memref::registerTransformDialectExtension(registry);
|
||||
scf::registerTransformDialectExtension(registry);
|
||||
|
||||
// Register all external models.
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
add_subdirectory(IR)
|
||||
add_subdirectory(TransformOps)
|
||||
add_subdirectory(Transforms)
|
||||
add_subdirectory(Utils)
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
add_mlir_dialect_library(MLIRMemRefTransformOps
|
||||
MemRefTransformOps.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MemRef/TransformOps
|
||||
|
||||
DEPENDS
|
||||
MLIRMemRefTransformOpsIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRAffineDialect
|
||||
MLIRArithmeticDialect
|
||||
MLIRIR
|
||||
MLIRPDLDialect
|
||||
MLIRMemRefDialect
|
||||
MLIRMemRefTransforms
|
||||
MLIRTransformDialect
|
||||
)
|
|
@ -0,0 +1,69 @@
|
|||
//===- MemRefTransformOps.cpp - Implementation of Memref transform ops ----===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/PDL/IR/PDL.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MemRefMultiBufferOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Multi-buffers one `memref.alloc` payload op.
///
/// Calls `memref::multiBuffer` on `target` with this op's `factor` attribute.
/// On success the allocation returned by `multiBuffer` is appended to
/// `results`. On failure a note-severity diagnostic attached to the location
/// of `target` is returned as a silenceable failure.
///
/// \param target  the allocation to multi-buffer.
/// \param results receives the new allocation on success.
/// \param state   the transform state (not used by this implementation).
DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::applyToOne(
    memref::AllocOp target, SmallVector<Operation *> &results,
    transform::TransformState &state) {
  FailureOr<memref::AllocOp> expanded =
      memref::multiBuffer(target, getFactor());

  // Success path: hand the new allocation back to the caller.
  if (succeeded(expanded)) {
    results.push_back(expanded.value());
    return DiagnosedSilenceableFailure(success());
  }

  // Failure path: report at the location of the allocation that was passed in.
  Diagnostic note(target->getLoc(), DiagnosticSeverity::Note);
  note << "op failed to multibuffer";
  return DiagnosedSilenceableFailure::silenceableFailure(std::move(note));
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Transform op registration
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// Transform dialect extension that contributes the MemRef transform ops
/// listed in the generated `MemRefTransformOps.cpp.inc`.
class MemRefTransformDialectExtension
|
||||
: public transform::TransformDialectExtension<
|
||||
MemRefTransformDialectExtension> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
/// Declares the dialects this extension relies on and registers its ops.
void init() {
|
||||
// The op definitions use PDL operation types for their handles.
declareDependentDialect<pdl::PDLDialect>();
|
||||
// NOTE(review): presumably declared because multi-buffering creates affine
// and arithmetic ops in the payload IR -- confirm against multiBuffer.
declareGeneratedDialect<AffineDialect>();
|
||||
declareGeneratedDialect<arith::ArithmeticDialect>();
|
||||
|
||||
// GET_OP_LIST makes the generated include expand to the list of op classes,
// which is used here as the template argument list.
registerTransformOps<
|
||||
#define GET_OP_LIST
|
||||
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
|
||||
>();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
|
||||
|
||||
void mlir::memref::registerTransformDialectExtension(
|
||||
DialectRegistry ®istry) {
|
||||
registry.addExtensions<MemRefTransformDialectExtension>();
|
||||
}
|
|
@ -78,8 +78,8 @@ static Value getOrCreateValue(OpFoldResult res, OpBuilder &builder,
|
|||
// Returns success if the transformation happened and failure otherwise.
|
||||
// This is not a pattern as it requires propagating the new memref type to its
|
||||
// uses and requires updating subview ops.
|
||||
LogicalResult mlir::memref::multiBuffer(memref::AllocOp allocOp,
|
||||
unsigned multiplier) {
|
||||
FailureOr<memref::AllocOp> mlir::memref::multiBuffer(memref::AllocOp allocOp,
|
||||
unsigned multiplier) {
|
||||
DominanceInfo dom(allocOp->getParentOp());
|
||||
LoopLikeOpInterface candidateLoop;
|
||||
for (Operation *user : allocOp->getUsers()) {
|
||||
|
@ -142,5 +142,5 @@ LogicalResult mlir::memref::multiBuffer(memref::AllocOp allocOp,
|
|||
offsets, sizes, strides);
|
||||
replaceUsesAndPropagateType(allocOp, subview, builder);
|
||||
allocOp.erase();
|
||||
return success();
|
||||
return newAlloc;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
// RUN: mlir-opt %s -test-transform-dialect-interpreter -verify-diagnostics -allow-unregistered-dialect | FileCheck %s
|
||||
|
||||
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (((d0 - d1) floordiv d2) mod 2)>
|
||||
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
|
||||
|
||||
// CHECK-LABEL: func @multi_buffer
|
||||
func.func @multi_buffer(%in: memref<16xf32>) {
|
||||
// CHECK: %[[A:.*]] = memref.alloc() : memref<2x4xf32>
|
||||
// expected-remark @below {{transformed}}
|
||||
%tmp = memref.alloc() : memref<4xf32>
|
||||
|
||||
// CHECK: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK: %[[C4:.*]] = arith.constant 4 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
%c4 = arith.constant 4 : index
|
||||
%c16 = arith.constant 16 : index
|
||||
|
||||
// CHECK: scf.for %[[IV:.*]] = %[[C0]]
|
||||
scf.for %i0 = %c0 to %c16 step %c4 {
|
||||
// CHECK: %[[I:.*]] = affine.apply #[[$MAP0]](%[[IV]], %[[C0]], %[[C4]])
|
||||
// CHECK: %[[SV:.*]] = memref.subview %[[A]][%[[I]], 0] [1, 4] [1, 1] : memref<2x4xf32> to memref<4xf32, strided<[1], offset: ?>>
|
||||
%1 = memref.subview %in[%i0] [4] [1] : memref<16xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
|
||||
// CHECK: memref.copy %{{.*}}, %[[SV]] : memref<4xf32, #[[$MAP1]]> to memref<4xf32, strided<[1], offset: ?>>
|
||||
memref.copy %1, %tmp : memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref<4xf32>
|
||||
|
||||
"some_use"(%tmp) : (memref<4xf32>) ->()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
transform.sequence failures(propagate) {
|
||||
^bb1(%arg1: !pdl.operation):
|
||||
%0 = transform.structured.match ops{["memref.alloc"]} in %arg1
|
||||
%1 = transform.memref.multibuffer %0 {factor = 2 : i64}
|
||||
// Verify that the returned handle is usable.
|
||||
transform.test_print_remark_at_operand %1, "transformed"
|
||||
}
|
Loading…
Reference in New Issue