[mlir][linalg][transform] Add TileOp to transform dialect

This commit adds a tiling op to the transform dialect as an external op.

Differential Revision: https://reviews.llvm.org/D124661
This commit is contained in:
Matthias Springer 2022-04-29 21:34:41 +09:00
parent e66127e69b
commit 3c2a74a3ae
12 changed files with 414 additions and 9 deletions

View File

@ -1,4 +1,5 @@
add_subdirectory(IR)
add_subdirectory(TransformOps)
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Linalg)

View File

@ -0,0 +1,4 @@
set(LLVM_TARGET_DEFINITIONS LinalgTransformOps.td)
mlir_tablegen(LinalgTransformOps.h.inc -gen-op-decls)
mlir_tablegen(LinalgTransformOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRLinalgTransformOpsIncGen)

View File

@ -0,0 +1,30 @@
//===- LinalgTransformOps.h - Linalg transform ops --------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H
#define MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/OpImplementation.h"
//===----------------------------------------------------------------------===//
// Linalg Transform Operations
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h.inc"
namespace mlir {
class DialectRegistry;
namespace linalg {
void registerTransformDialectExtension(DialectRegistry &registry);
} // namespace linalg
} // namespace mlir
#endif // MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H

View File

@ -0,0 +1,45 @@
//===- LinalgTransformOps.td - Linalg transform ops --------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef LINALG_TRANSFORM_OPS
#define LINALG_TRANSFORM_OPS
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/IR/TransformEffects.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/Dialect/PDL/IR/PDLTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"
def TileOp : Op<Transform_Dialect, "structured.tile",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let description = [{
Indicates that the given `target` op should be tiled with the options
provided as attributes. This transform generates a loop nest with a smaller
("tiled") target operation in its body. Currently limited to LinalgOps.
`sizes` are the tile sizes. A tile size of `0` indicates that the
respective dimension should not be tiled. No loop will be generated for such
dimensions. If all tile sizes are `0`, this transform is effectively a
no-op.
This op returns handles to the tiled op (in the generated loop nest) and the
generated loops. The number of loops is the number of non-zero tile sizes.
}];
let arguments = (ins PDL_Operation:$target,
DefaultValuedAttr<I64ArrayAttr, "{}">:$sizes,
DefaultValuedAttr<I64ArrayAttr, "{}">:$interchange);
let results = (outs PDL_Operation:$tiled_linalg_op,
Variadic<PDL_Operation>:$loops);
let hasCustomAssemblyFormat = 1;
}
#endif // LINALG_TRANSFORM_OPS

View File

@ -33,6 +33,7 @@
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Math/IR/Math.h"
@ -101,6 +102,11 @@ inline void registerAllDialects(DialectRegistry &registry) {
tosa::TosaDialect,
x86vector::X86VectorDialect>();
// clang-format on
// Register all dialect extensions.
linalg::registerTransformDialectExtension(registry);
// Register all external models.
arith::registerBufferizableOpInterfaceExternalModels(registry);
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
registry);

View File

@ -1,4 +1,5 @@
add_subdirectory(Analysis)
add_subdirectory(IR)
add_subdirectory(TransformOps)
add_subdirectory(Transforms)
add_subdirectory(Utils)

View File

@ -0,0 +1,18 @@
add_mlir_dialect_library(MLIRLinalgTransformOps
LinalgTransformOps.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg/TransformOps
DEPENDS
MLIRLinalgTransformOpsIncGen
LINK_LIBS PUBLIC
MLIRIR
MLIRLinalg
MLIRLinalgTransforms
MLIRParser
MLIRPDL
MLIRSideEffectInterfaces
MLIRTransformDialect
)

View File

@ -0,0 +1,198 @@
//===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Parser/Parser.h"
#include "llvm/Support/FormatVariadic.h"
using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::transform;
/// Extracts a vector of int64_t from an array attribute. Asserts if the
/// attribute contains values other than integers.
static SmallVector<int64_t> extractI64Array(ArrayAttr attr) {
SmallVector<int64_t> result;
result.reserve(attr.size());
for (APInt value : attr.getAsValueRange<IntegerAttr>())
result.push_back(value.getSExtValue());
return result;
}
/// Extracts a vector of unsigned from an array attribute. Asserts if the
/// attribute contains values other than intergers. May truncate.
static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
SmallVector<unsigned> result;
result.reserve(attr.size());
for (APInt value : attr.getAsValueRange<IntegerAttr>())
result.push_back(value.getZExtValue());
return result;
}
namespace {
/// A simple pattern rewriter that implements no special logic.
class SimpleRewriter : public PatternRewriter {
public:
SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
};
} // namespace
//===----------------------------------------------------------------------===//
// TileOp
//===----------------------------------------------------------------------===//
/// Apply a tiling transformation to all payload ops and store both the
/// tiled operation as well as the created tile loops.
static LogicalResult
applyTilingToAll(Operation *transformOp, Value target,
ArrayRef<int64_t> tileSizes,
transform::TransformResults &transformResults,
transform::TransformState &state,
function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
// Number of loops: Number of tiles sizes that are not zero.
size_t numLoops = tileSizes.size() - llvm::count(tileSizes, 0);
// All payload ops. These should all be LinalgOps for now.
ArrayRef<Operation *> payloadOps = state.getPayloadOps(target);
SmallVector<Operation *> tiledLinalgOps;
SmallVector<SmallVector<Operation *>> loopOps(numLoops);
for (unsigned int i = 0; i < numLoops; ++i)
loopOps[i].reserve(payloadOps.size());
for (Operation *target : payloadOps) {
auto linalgOp = dyn_cast<linalg::LinalgOp>(target);
if (!linalgOp)
return transformOp->emitError("only LinalgOps are supported");
FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp);
if (failed(tiled))
return failure();
tiledLinalgOps.push_back(tiled->op);
if (tiled->loops.size() != numLoops)
// Not enough loops were generated. This usually means that the input size
// was smaller than the tiling size.
// TODO: LinalgTilingPattern should return failure().
return failure();
for (unsigned int i = 0; i < numLoops; ++i)
loopOps[i].push_back(tiled->loops[i]);
}
transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
for (unsigned int i = 0; i < numLoops; ++i)
transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
return success();
}
LogicalResult transform::TileOp::apply(TransformResults &transformResults,
TransformState &state) {
LinalgTilingOptions tilingOptions;
SmallVector<int64_t> tileSizes = extractI64Array(getSizes());
if (!tileSizes.empty())
tilingOptions.setTileSizes(tileSizes);
tilingOptions.setInterchange(extractUIntArray(getInterchange()));
LinalgTilingPattern pattern(getContext(), tilingOptions);
return applyTilingToAll(getOperation(), getTarget(), tileSizes,
transformResults, state, [&](LinalgOp linalgOp) {
SimpleRewriter rewriter(linalgOp.getContext());
return pattern.returningMatchAndRewrite(linalgOp,
rewriter);
});
}
ParseResult transform::TileOp::parse(OpAsmParser &parser,
OperationState &result) {
StringRef sizesAttrName = TileOp::getSizesAttrName(result.name).getValue();
OpAsmParser::UnresolvedOperand targetOperand;
SMLoc opLoc;
parser.getCurrentLocation(&opLoc);
if (parser.parseOperand(targetOperand))
return parser.emitError(opLoc, "expected 'target' operand");
if (parser.parseOptionalAttrDict(result.attributes))
return failure();
Attribute sizesAttr = result.attributes.get(sizesAttrName);
if (!sizesAttr)
return parser.emitError(opLoc)
<< "expected '" << sizesAttrName << "' attribute";
auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>();
if (!sizesArrayAttr)
return parser.emitError(opLoc)
<< "'" << sizesAttrName << "' attribute must be an array";
Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
size_t numExpectedLoops =
sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0);
result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType));
if (parser.resolveOperand(targetOperand, pdlOpType, result.operands))
return failure();
return success();
}
void TileOp::print(OpAsmPrinter &p) {
p << ' ';
p << getTarget();
p.printOptionalAttrDict((*this)->getAttrs());
}
void TileOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
// `target` arg is consumed and can no longer be used.
effects.emplace_back(MemoryEffects::Read::get(), getTarget(),
TransformMappingResource::get());
effects.emplace_back(MemoryEffects::Free::get(), getTarget(),
TransformMappingResource::get());
for (Value r : getResults()) {
effects.emplace_back(MemoryEffects::Write::get(), r,
TransformMappingResource::get());
effects.emplace_back(MemoryEffects::Allocate::get(), r,
TransformMappingResource::get());
}
effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
}
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
namespace {
/// Registers new ops and declares PDL as dependent dialect since the additional
/// ops are using PDL types for operands and results.
class LinalgTransformDialectExtension
: public transform::TransformDialectExtension<
LinalgTransformDialectExtension> {
public:
LinalgTransformDialectExtension() {
declareDependentDialect<pdl::PDLDialect>();
declareDependentDialect<scf::SCFDialect>();
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
>();
}
};
} // namespace
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
void mlir::linalg::registerTransformDialectExtension(
DialectRegistry &registry) {
registry.addExtensions<LinalgTransformDialectExtension>();
}

View File

@ -168,8 +168,8 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
// Shift all IndexOp results by the tile offset.
SmallVector<Value> allIvs;
transform(loopRanges, std::back_inserter(allIvs),
[](Range range) { return range.offset; });
llvm::transform(loopRanges, std::back_inserter(allIvs),
[](Range range) { return range.offset; });
addTileLoopIvsToIndexOpResults(b, clonedOp, allIvs);
return clonedOp;

View File

@ -87,10 +87,11 @@ getTiledProducerLoops(OpResult producerResult,
assert(tiledProducerIndexingSubMap.isProjectedPermutation() &&
"expect slice and producer loop dimensions map one-to-one");
SmallVector<int64_t> tiledProducerLoopIndices;
transform(llvm::seq<unsigned>(0, tiledProducerIndexingSubMap.getNumResults()),
std::back_inserter(tiledProducerLoopIndices), [&](unsigned idx) {
return tiledProducerIndexingSubMap.getDimPosition(idx);
});
llvm::transform(
llvm::seq<unsigned>(0, tiledProducerIndexingSubMap.getNumResults()),
std::back_inserter(tiledProducerLoopIndices), [&](unsigned idx) {
return tiledProducerIndexingSubMap.getDimPosition(idx);
});
return tiledProducerLoopIndices;
}
@ -141,9 +142,9 @@ static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
// Obtain the `producerOp` loop bounds and the `sliceOp` ranges.
SmallVector<Value> producerLoopBounds;
transform(producerOp.createLoopRanges(b, loc),
std::back_inserter(producerLoopBounds),
[](Range range) { return range.size; });
llvm::transform(producerOp.createLoopRanges(b, loc),
std::back_inserter(producerLoopBounds),
[](Range range) { return range.size; });
SmallVector<Range> sliceOpRanges = sliceOp.getOrCreateRanges(b, loc);
// Tile the producer operands given the `sliceOp` ranges. Iterate the

View File

@ -0,0 +1,46 @@
// RUN: mlir-opt --test-transform-dialect-interpreter %s | FileCheck %s
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
sequence %arg0 {
^bb0(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%1, %loops:3 = transform.structured.tile %0 {sizes = [4, 4, 4]}
}
pdl.pattern @pdl_target : benefit(1) {
%args = operands
%results = types
%0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
rewrite %0 with "transform.dialect"
}
}
// CHECK-LABEL: func @tile_linalg_matmul(
// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<128x128xf32>
// CHECK-SAME: %[[TB:[0-9a-z]+]]: tensor<128x128xf32>
// CHECK-SAME: %[[TC:[0-9a-z]+]]: tensor<128x128xf32>
// CHECK-SAME: -> tensor<128x128xf32> {
func @tile_linalg_matmul(
%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
-> tensor<128x128xf32> {
// CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<128x128xf32>) {
// CHECK: %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<128x128xf32>) {
// CHECK: %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<128x128xf32>) {
// CHECK: %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor<128x128xf32> to tensor<4x4xf32>
// CHECK: %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor<128x128xf32> to tensor<4x4xf32>
// CHECK: %[[sTC:.*]] = tensor.extract_slice %[[TC2]][{{.*}}] : tensor<128x128xf32> to tensor<4x4xf32>
// CHECK: %[[sTD:.*]] = linalg.matmul ins(%[[sTA]], %[[sTB]] : tensor<4x4xf32>, tensor<4x4xf32>)
// CHECK-SAME: outs(%[[sTC]] : tensor<4x4xf32>) -> tensor<4x4xf32>
// CHECK: %[[TD:.*]] = tensor.insert_slice %[[sTD]] into %[[TC2]][{{.*}}] : tensor<4x4xf32> into tensor<128x128xf32>
// CHECK: scf.yield %[[TD]] : tensor<128x128xf32>
// CHECK: scf.yield %[[TD2]] : tensor<128x128xf32>
// CHECK: scf.yield %[[TD1]] : tensor<128x128xf32>
%0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
outs(%arg2: tensor<128x128xf32>)
-> tensor<128x128xf32>
// CHECK: return %[[TD0]] : tensor<128x128xf32>
return %0 : tensor<128x128xf32>
}

View File

@ -6090,6 +6090,7 @@ cc_library(
":LinalgToLLVM",
":LinalgToSPIRV",
":LinalgToStandard",
":LinalgTransformOps",
":LinalgTransforms",
":MLProgramDialect",
":MathDialect",
@ -6905,6 +6906,18 @@ td_library(
],
)
td_library(
name = "LinalgTransformOpsTdFiles",
srcs = [
"include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td",
],
includes = ["include"],
deps = [
":PDLDialectTdFiles",
":TransformDialectTdFiles",
],
)
gentbl_cc_library(
name = "LinalgOpsIncGen",
strip_include_prefix = "include",
@ -6953,6 +6966,26 @@ gentbl_cc_library(
deps = [":LinalgOpsTdFiles"],
)
gentbl_cc_library(
name = "LinalgTransformOpsIncGen",
strip_include_prefix = "include",
tbl_outs = [
(
["-gen-op-decls"],
"include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h.inc",
),
(
["-gen-op-defs"],
"include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc",
),
],
tblgen = ":mlir-tblgen",
td_file = "include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td",
deps = [
":LinalgTransformOpsTdFiles",
],
)
genlinalg(
name = "LinalgNamedStructuredOpsYamlIncGen",
src = "include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml",
@ -7200,6 +7233,28 @@ cc_library(
],
)
cc_library(
name = "LinalgTransformOps",
srcs = [
"lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp",
],
hdrs = [
"include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h",
],
includes = ["include"],
deps = [
":IR",
":LinalgOps",
":LinalgTransformOpsIncGen",
":LinalgTransforms",
":PDLDialect",
":Parser",
":SideEffectInterfaces",
":TransformDialect",
"//llvm:Support",
],
)
gentbl_cc_library(
name = "LinalgPassIncGen",
strip_include_prefix = "include",