[mlir][Linalg] Move named op conversions out of canonicalizations.

These conversions are better suited to being applied at the whole-tensor
level. Applying them as canonicalizations ends up triggering them at
all levels of the stack, which can be undesirable. For example, some of
the resulting code patterns won't bufferize in place and need
additional stack buffers. It is better to be deliberate about when
these conversions are applied.

Differential Revision: https://reviews.llvm.org/D115912
Author: MaheshRavishankar
Date:   2021-12-20 09:34:41 -08:00
Commit: 4142932a83
Parent: bee5bc9075
9 changed files with 202 additions and 139 deletions
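
As a usage sketch (not part of this commit), a downstream pipeline can
now schedule this rewrite deliberately instead of inheriting it from
every canonicalizer run. Assuming an MLIRContext context and a ModuleOp
module are in scope:

#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h" // createCanonicalizerPass

// Sketch: run the named-op conversion once, at the whole-tensor level,
// rather than letting every -canonicalize invocation trigger it.
mlir::PassManager pm(&context);
pm.addPass(mlir::createLinalgNamedOpConversionPass());
pm.addPass(mlir::createCanonicalizerPass()); // no longer performs this rewrite
if (mlir::failed(pm.run(module)))
  llvm::errs() << "pipeline failed\n";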


@@ -26,6 +26,8 @@ std::unique_ptr<Pass> createLinalgFoldUnitExtentDimsPass();
std::unique_ptr<Pass> createLinalgElementwiseOpFusionPass();
std::unique_ptr<Pass> createFoldReshapeOpsByLinearizationPass();
std::unique_ptr<Pass> createLinalgNamedOpConversionPass();
std::unique_ptr<OperationPass<FuncOp>> createLinalgTilingPass(
ArrayRef<int64_t> tileSizes = {},
linalg::LinalgTilingLoopType loopType = linalg::LinalgTilingLoopType::Loops,


@@ -100,6 +100,12 @@ def LinalgFoldReshapeOpsByLinearization :
let dependentDialects = ["AffineDialect", "memref::MemRefDialect"];
}
def LinalgNamedOpConversion: Pass<"linalg-named-op-conversion"> {
let summary = "Convert from one named linalg op to another.";
let constructor = "mlir::createLinalgNamedOpConversionPass()";
let dependentDialects = ["linalg::LinalgDialect", "tensor::TensorDialect"];
}
def LinalgLowerTiledLoopsToSCF
: FunctionPass<"convert-linalg-tiled-loops-to-scf"> {
let summary = "Lower linalg tiled loops to SCF loops and parallel loops";


@@ -86,6 +86,10 @@ void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
RewritePatternSet &patterns);
/// Patterns to convert from one named op to another. These can be seen as
/// canonicalizations of named ops into another named op.
void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns);
/// Populates the given list with patterns to bufferize linalg ops.
void populateLinalgBufferizePatterns(
bufferization::BufferizeTypeConverter &converter,
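
A minimal sketch of driving these patterns from downstream code
(assuming a hypothetical caller; this mirrors the runOnOperation of
the LinalgNamedOpConversionPass added later in this commit):

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Sketch: apply the named-op conversion patterns greedily to `op`.
static mlir::LogicalResult runNamedOpConversions(mlir::Operation *op) {
  mlir::RewritePatternSet patterns(op->getContext());
  mlir::linalg::populateLinalgNamedOpConversionPatterns(patterns);
  // Rewrites depthwise_conv_2d_nhwc_hwcm(_q) with unit multiplier into
  // the hwc(_q) form, wrapped in collapse_shape/expand_shape.
  return mlir::applyPatternsAndFoldGreedily(op, std::move(patterns));
}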


@@ -2665,118 +2665,6 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
}
};
static llvm::SmallVector<int64_t> getIndicesVector(int start, int end) {
return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));
}
LogicalResult matchAndReplaceDepthwiseConv(Operation *operation, Value input,
Value kernel, Value iZp, Value kZp,
Value init, Attribute stride,
Attribute dilation,
PatternRewriter &rewriter) {
Location loc = operation->getLoc();
auto linalgOp = dyn_cast<LinalgOp>(operation);
// Exit out on the memref version of this operation.
if (!linalgOp || !linalgOp.hasTensorSemantics())
return failure();
auto result = operation->getResult(0);
auto kernelTy = kernel.getType().dyn_cast<RankedTensorType>();
auto initTy = init.getType().dyn_cast<RankedTensorType>();
auto resultTy = result.getType().template dyn_cast<RankedTensorType>();
if (!kernelTy || !initTy || !resultTy)
return failure();
if (kernelTy.getDimSize(3) != 1)
return failure();
// Collapse kernel dims.
SmallVector<ReassociationIndices, 4> collapsedKernelDims = {
getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 4)};
auto newKernelTy = RankedTensorType::get(
{kernelTy.getDimSize(0), kernelTy.getDimSize(1), kernelTy.getDimSize(2)},
kernelTy.getElementType());
auto collapsedKernel = rewriter.create<tensor::CollapseShapeOp>(
loc, newKernelTy, kernel, collapsedKernelDims);
// Collapse init dims.
SmallVector<ReassociationIndices, 4> collapsedInitDims = {
getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 3),
getIndicesVector(3, 5)};
auto newInitTy =
RankedTensorType::get({initTy.getDimSize(0), initTy.getDimSize(1),
initTy.getDimSize(2), initTy.getDimSize(3)},
initTy.getElementType());
auto collapsedInit = rewriter.create<tensor::CollapseShapeOp>(
loc, newInitTy, init, collapsedInitDims);
Value newConv;
if (isa<DepthwiseConv2DNhwcHwcmOp>(operation)) {
newConv = rewriter
.create<DepthwiseConv2DNhwcHwcOp>(
loc, newInitTy, ValueRange{input, collapsedKernel},
ValueRange{collapsedInit}, stride, dilation)
.getResult(0);
} else if (isa<DepthwiseConv2DNhwcHwcmQOp>(operation)) {
newConv =
rewriter
.create<DepthwiseConv2DNhwcHwcQOp>(
loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp},
ValueRange{collapsedInit}, stride, dilation)
.getResult(0);
}
if (!newConv)
return failure();
// Expand dimensions back out to match the rank of the original result.
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
operation, resultTy, newConv, collapsedInitDims);
return success();
}
struct SimplifyDepthwiseConvOp
: public OpRewritePattern<DepthwiseConv2DNhwcHwcmOp> {
using OpRewritePattern<DepthwiseConv2DNhwcHwcmOp>::OpRewritePattern;
LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmOp op,
PatternRewriter &rewriter) const override {
Operation *operation = op.getOperation();
Value input = op.getInputOperand(0)->get();
Value kernel = op.getInputOperand(1)->get();
Value init = op.getOutputOperand(0)->get();
auto stride = op.strides();
auto dilation = op.dilations();
return matchAndReplaceDepthwiseConv(operation, input, kernel, nullptr,
nullptr, init, stride, dilation,
rewriter);
}
};
struct SimplifyDepthwiseConvQOp
: public OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp> {
using OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp>::OpRewritePattern;
LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmQOp op,
PatternRewriter &rewriter) const override {
Operation *operation = op.getOperation();
Value input = op.getInputOperand(0)->get();
Value kernel = op.getInputOperand(1)->get();
Value iZp = op.getInputOperand(2)->get();
Value kZp = op.getInputOperand(3)->get();
Value init = op.getOutputOperand(0)->get();
auto stride = op.strides();
auto dilation = op.dilations();
return matchAndReplaceDepthwiseConv(operation, input, kernel, iZp, kZp,
init, stride, dilation, rewriter);
}
};
} // namespace
#define LINALGOP_FOLDERS(XXX) \
@@ -2798,8 +2686,7 @@ LINALGOP_FOLDERS(GenericOp)
void LinalgDialect::getCanonicalizationPatterns(
RewritePatternSet &results) const {
-  results.add<EraseDeadLinalgOp, FoldTensorCastOp, SimplifyDepthwiseConvOp,
-              SimplifyDepthwiseConvQOp>(getContext());
+  results.add<EraseDeadLinalgOp, FoldTensorCastOp>(getContext());
}
Operation *LinalgDialect::materializeConstant(OpBuilder &builder,


@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
Interchange.cpp
Loops.cpp
LinalgStrategyPasses.cpp
NamedOpConversions.cpp
Promotion.cpp
Tiling.cpp
Transforms.cpp


@@ -0,0 +1,160 @@
//===- NamedOpConversions.cpp - Implements conversions between named ops --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements conversions between named ops that can be seen as
// canonicalizations of named ops.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallVector.h"
using namespace mlir;
using namespace mlir::linalg;
static llvm::SmallVector<int64_t> getIndicesVector(int start, int end) {
return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));
}
static LogicalResult
matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel,
Value iZp, Value kZp, Value init, Attribute stride,
Attribute dilation, PatternRewriter &rewriter) {
Location loc = operation->getLoc();
auto linalgOp = dyn_cast<LinalgOp>(operation);
// Exit out on the memref version of this operation.
if (!linalgOp || !linalgOp.hasTensorSemantics())
return failure();
auto result = operation->getResult(0);
auto kernelTy = kernel.getType().dyn_cast<RankedTensorType>();
auto initTy = init.getType().dyn_cast<RankedTensorType>();
auto resultTy = result.getType().template dyn_cast<RankedTensorType>();
if (!kernelTy || !initTy || !resultTy)
return failure();
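// The trailing kernel dimension is the channel multiplier; the hwcm ->
// hwc rewrite only applies when it is statically 1.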
if (kernelTy.getDimSize(3) != 1)
return failure();
// Collapse kernel dims.
SmallVector<ReassociationIndices, 4> collapsedKernelDims = {
getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 4)};
auto newKernelTy = RankedTensorType::get(
{kernelTy.getDimSize(0), kernelTy.getDimSize(1), kernelTy.getDimSize(2)},
kernelTy.getElementType());
auto collapsedKernel = rewriter.create<tensor::CollapseShapeOp>(
loc, newKernelTy, kernel, collapsedKernelDims);
// Collapse init dims.
SmallVector<ReassociationIndices, 4> collapsedInitDims = {
getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 3),
getIndicesVector(3, 5)};
auto newInitTy =
RankedTensorType::get({initTy.getDimSize(0), initTy.getDimSize(1),
initTy.getDimSize(2), initTy.getDimSize(3)},
initTy.getElementType());
auto collapsedInit = rewriter.create<tensor::CollapseShapeOp>(
loc, newInitTy, init, collapsedInitDims);
Value newConv;
if (isa<DepthwiseConv2DNhwcHwcmOp>(operation)) {
newConv = rewriter
.create<DepthwiseConv2DNhwcHwcOp>(
loc, newInitTy, ValueRange{input, collapsedKernel},
ValueRange{collapsedInit}, stride, dilation)
.getResult(0);
} else if (isa<DepthwiseConv2DNhwcHwcmQOp>(operation)) {
newConv =
rewriter
.create<DepthwiseConv2DNhwcHwcQOp>(
loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp},
ValueRange{collapsedInit}, stride, dilation)
.getResult(0);
}
if (!newConv)
return failure();
// Expand dimensions back out to match the rank of the original result.
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
operation, resultTy, newConv, collapsedInitDims);
return success();
}
namespace {
struct SimplifyDepthwiseConvOp
: public OpRewritePattern<DepthwiseConv2DNhwcHwcmOp> {
using OpRewritePattern<DepthwiseConv2DNhwcHwcmOp>::OpRewritePattern;
LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmOp op,
PatternRewriter &rewriter) const override {
Operation *operation = op.getOperation();
Value input = op.getInputOperand(0)->get();
Value kernel = op.getInputOperand(1)->get();
Value init = op.getOutputOperand(0)->get();
auto stride = op.strides();
auto dilation = op.dilations();
return matchAndReplaceDepthwiseConv(operation, input, kernel, nullptr,
nullptr, init, stride, dilation,
rewriter);
}
};
struct SimplifyDepthwiseConvQOp
: public OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp> {
using OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp>::OpRewritePattern;
LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmQOp op,
PatternRewriter &rewriter) const override {
Operation *operation = op.getOperation();
Value input = op.getInputOperand(0)->get();
Value kernel = op.getInputOperand(1)->get();
Value iZp = op.getInputOperand(2)->get();
Value kZp = op.getInputOperand(3)->get();
Value init = op.getOutputOperand(0)->get();
auto stride = op.strides();
auto dilation = op.dilations();
return matchAndReplaceDepthwiseConv(operation, input, kernel, iZp, kZp,
init, stride, dilation, rewriter);
}
};
struct LinalgNamedOpConversionPass
: public LinalgNamedOpConversionBase<LinalgNamedOpConversionPass> {
LinalgNamedOpConversionPass() = default;
LinalgNamedOpConversionPass(const LinalgNamedOpConversionPass &) {}
void runOnOperation() override {
Operation *op = getOperation();
RewritePatternSet patterns(op->getContext());
populateLinalgNamedOpConversionPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
return signalPassFailure();
}
};
} // namespace
void mlir::linalg::populateLinalgNamedOpConversionPatterns(
RewritePatternSet &patterns) {
patterns.add<SimplifyDepthwiseConvOp, SimplifyDepthwiseConvQOp>(
patterns.getContext());
}
std::unique_ptr<Pass> mlir::createLinalgNamedOpConversionPass() {
return std::make_unique<LinalgNamedOpConversionPass>();
}


@@ -38,6 +38,10 @@ namespace memref {
class MemRefDialect;
} // namespace memref
namespace tensor {
class TensorDialect;
} // namespace tensor
namespace vector {
class VectorDialect;
} // namespace vector


@@ -758,28 +758,3 @@ func @dim_of_tiled_loop_result_no_canonicalize(%arg0: tensor<?x?xf32>, %arg1: te
%r2 = tensor.dim %r, %c0 : tensor<?x?xf32>
return %r2 : index
}
// -----
// CHECK-LABEL: @depthwise_conv
func @depthwise_conv(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x1xf32>, %arg2: tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32> {
// CHECK-DAG: %[[KERNEL:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]]
// CHECK-DAG: %[[INIT:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]]
// CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]] : tensor<?x?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[INIT]] : tensor<?x?x?x?xf32>)
// CHECK: %[[OUT:.+]] = tensor.expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]]
%0 = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x1xf32>) outs(%arg2 : tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32>
return %0 : tensor<?x?x?x?x1xf32>
}
// -----
// CHECK-LABEL: @depthwise_conv_q
func @depthwise_conv_q(%arg0: tensor<?x?x?x?xi8>, %arg1: tensor<?x?x?x1xi8>, %arg2: tensor<?x?x?x?x1xi32>, %arg3 : i32, %arg4 : i32) -> tensor<?x?x?x?x1xi32> {
// CHECK-DAG: %[[KERNEL:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]]
// CHECK-DAG: %[[INIT:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]]
// CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]], %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?xi8>, i32, i32) outs(%[[INIT]] : tensor<?x?x?x?xi32>)
// CHECK: %[[OUT:.+]] = tensor.expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]]
%0 = linalg.depthwise_conv_2d_nhwc_hwcm_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1, %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?x1xi8>, i32, i32) outs(%arg2 : tensor<?x?x?x?x1xi32>) -> tensor<?x?x?x?x1xi32>
return %0 : tensor<?x?x?x?x1xi32>
}


@@ -0,0 +1,24 @@
// RUN: mlir-opt %s -linalg-named-op-conversion -split-input-file | FileCheck %s
// CHECK-LABEL: @depthwise_conv
func @depthwise_conv(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x1xf32>, %arg2: tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32> {
// CHECK-DAG: %[[KERNEL:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]]
// CHECK-DAG: %[[INIT:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]]
// CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]] : tensor<?x?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[INIT]] : tensor<?x?x?x?xf32>)
// CHECK: %[[OUT:.+]] = tensor.expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]]
%0 = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x1xf32>) outs(%arg2 : tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32>
return %0 : tensor<?x?x?x?x1xf32>
}
// -----
// CHECK-LABEL: @depthwise_conv_q
func @depthwise_conv_q(%arg0: tensor<?x?x?x?xi8>, %arg1: tensor<?x?x?x1xi8>, %arg2: tensor<?x?x?x?x1xi32>, %arg3 : i32, %arg4 : i32) -> tensor<?x?x?x?x1xi32> {
// CHECK-DAG: %[[KERNEL:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]]
// CHECK-DAG: %[[INIT:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]]
// CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]], %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?xi8>, i32, i32) outs(%[[INIT]] : tensor<?x?x?x?xi32>)
// CHECK: %[[OUT:.+]] = tensor.expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]]
%0 = linalg.depthwise_conv_2d_nhwc_hwcm_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1, %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?x1xi8>, i32, i32) outs(%arg2 : tensor<?x?x?x?x1xi32>) -> tensor<?x?x?x?x1xi32>
return %0 : tensor<?x?x?x?x1xi32>
}