From 4142932a834f0dca9e9ae0c3754f097ffa3fc1ef Mon Sep 17 00:00:00 2001
From: MaheshRavishankar
Date: Mon, 20 Dec 2021 09:34:41 -0800
Subject: [PATCH] [mlir][Linalg] Move named op conversions out of
 canonicalizations.

These conversions are better suited to being applied at the whole-tensor
level. Applying them as canonicalizations ends up triggering such
canonicalizations at all levels of the stack, which might be undesirable.
For example, some of the resulting code patterns won't bufferize in-place
and need additional stack buffers. It is better to be deliberate about
when these conversions apply.

Differential Revision: https://reviews.llvm.org/D115912
---
 mlir/include/mlir/Dialect/Linalg/Passes.h     |   2 +
 mlir/include/mlir/Dialect/Linalg/Passes.td    |   6 +
 .../Dialect/Linalg/Transforms/Transforms.h    |   4 +
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 115 +------------
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |   1 +
 .../Linalg/Transforms/NamedOpConversions.cpp  | 160 ++++++++++++++++++
 .../Dialect/Linalg/Transforms/PassDetail.h    |   4 +
 mlir/test/Dialect/Linalg/canonicalize.mlir    |  25 ---
 .../Dialect/Linalg/namedop_conversion.mlir    |  24 +++
 9 files changed, 202 insertions(+), 139 deletions(-)
 create mode 100644 mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
 create mode 100644 mlir/test/Dialect/Linalg/namedop_conversion.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index f21252800af8..8ebaaa8f8e4d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -26,6 +26,8 @@ std::unique_ptr<Pass> createLinalgFoldUnitExtentDimsPass();
 std::unique_ptr<Pass> createLinalgElementwiseOpFusionPass();
 std::unique_ptr<Pass> createFoldReshapeOpsByLinearizationPass();
 
+std::unique_ptr<Pass> createLinalgNamedOpConversionPass();
+
 std::unique_ptr<OperationPass<FuncOp>> createLinalgTilingPass(
     ArrayRef<int64_t> tileSizes = {},
     linalg::LinalgTilingLoopType loopType = linalg::LinalgTilingLoopType::Loops,
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 504bc562148f..5bcc8cc6e33f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -100,6 +100,12 @@ def LinalgFoldReshapeOpsByLinearization :
   let dependentDialects = ["AffineDialect", "memref::MemRefDialect"];
 }
 
+def LinalgNamedOpConversion: Pass<"linalg-named-op-conversion"> {
+  let summary = "Convert from one named linalg op to another.";
+  let constructor = "mlir::createLinalgNamedOpConversionPass()";
+  let dependentDialects = ["linalg::LinalgDialect", "tensor::TensorDialect"];
+}
+
 def LinalgLowerTiledLoopsToSCF
     : FunctionPass<"convert-linalg-tiled-loops-to-scf"> {
   let summary = "Lower linalg tiled loops to SCF loops and parallel loops";
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index c14259f7baba..34eef99dc729 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -86,6 +86,10 @@ void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
 void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
     RewritePatternSet &patterns);
 
+/// Patterns to convert from one named op to another. These can be seen as
+/// canonicalizations of named ops into another named op.
+void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns);
+
 /// Populates the given list with patterns to bufferize linalg ops.
 void populateLinalgBufferizePatterns(
     bufferization::BufferizeTypeConverter &converter,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 6e8b08bcbb8e..26a0c9277b32 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2665,118 +2665,6 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
   }
 };
 
-static llvm::SmallVector<int64_t, 2> getIndicesVector(int start, int end) {
-  return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));
-}
-
-LogicalResult matchAndReplaceDepthwiseConv(Operation *operation, Value input,
-                                           Value kernel, Value iZp, Value kZp,
-                                           Value init, Attribute stride,
-                                           Attribute dilation,
-                                           PatternRewriter &rewriter) {
-  Location loc = operation->getLoc();
-  auto linalgOp = dyn_cast<LinalgOp>(operation);
-  // Exit out on the memref version of this operation.
-  if (!linalgOp || !linalgOp.hasTensorSemantics())
-    return failure();
-
-  auto result = operation->getResult(0);
-
-  auto kernelTy = kernel.getType().dyn_cast<RankedTensorType>();
-  auto initTy = init.getType().dyn_cast<RankedTensorType>();
-  auto resultTy = result.getType().template dyn_cast<RankedTensorType>();
-  if (!kernelTy || !initTy || !resultTy)
-    return failure();
-
-  if (kernelTy.getDimSize(3) != 1)
-    return failure();
-
-  // Collapse kernel dims.
-  SmallVector<ReassociationIndices, 4> collapsedKernelDims = {
-      getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 4)};
-  auto newKernelTy = RankedTensorType::get(
-      {kernelTy.getDimSize(0), kernelTy.getDimSize(1), kernelTy.getDimSize(2)},
-      kernelTy.getElementType());
-  auto collapsedKernel = rewriter.create<tensor::CollapseShapeOp>(
-      loc, newKernelTy, kernel, collapsedKernelDims);
-
-  // Collapse init dims.
-  SmallVector<ReassociationIndices, 4> collapsedInitDims = {
-      getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 3),
-      getIndicesVector(3, 5)};
-  auto newInitTy =
-      RankedTensorType::get({initTy.getDimSize(0), initTy.getDimSize(1),
-                             initTy.getDimSize(2), initTy.getDimSize(3)},
-                            initTy.getElementType());
-  auto collapsedInit = rewriter.create<tensor::CollapseShapeOp>(
-      loc, newInitTy, init, collapsedInitDims);
-
-  Value newConv;
-  if (isa<DepthwiseConv2DNhwcHwcmOp>(operation)) {
-    newConv = rewriter
-                  .create<DepthwiseConv2DNhwcHwcOp>(
-                      loc, newInitTy, ValueRange{input, collapsedKernel},
-                      ValueRange{collapsedInit}, stride, dilation)
-                  .getResult(0);
-  } else if (isa<DepthwiseConv2DNhwcHwcmQOp>(operation)) {
-    newConv =
-        rewriter
-            .create<DepthwiseConv2DNhwcHwcQOp>(
-                loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp},
-                ValueRange{collapsedInit}, stride, dilation)
-            .getResult(0);
-  }
-
-  if (!newConv)
-    return failure();
-
-  // Expand dimensions back out to match the original output shape.
-  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
-      operation, resultTy, newConv, collapsedInitDims);
-  return success();
-}
-
-struct SimplifyDepthwiseConvOp
-    : public OpRewritePattern<DepthwiseConv2DNhwcHwcmOp> {
-  using OpRewritePattern<DepthwiseConv2DNhwcHwcmOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmOp op,
-                                PatternRewriter &rewriter) const override {
-    Operation *operation = op.getOperation();
-    Value input = op.getInputOperand(0)->get();
-    Value kernel = op.getInputOperand(1)->get();
-    Value init = op.getOutputOperand(0)->get();
-
-    auto stride = op.strides();
-    auto dilation = op.dilations();
-
-    return matchAndReplaceDepthwiseConv(operation, input, kernel, nullptr,
-                                        nullptr, init, stride, dilation,
-                                        rewriter);
-  }
-};
-
-struct SimplifyDepthwiseConvQOp
-    : public OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp> {
-  using OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmQOp op,
-                                PatternRewriter &rewriter) const override {
-    Operation *operation = op.getOperation();
-    Value input = op.getInputOperand(0)->get();
-    Value kernel = op.getInputOperand(1)->get();
-    Value iZp = op.getInputOperand(2)->get();
-    Value kZp = op.getInputOperand(3)->get();
-    Value init = op.getOutputOperand(0)->get();
-
-    auto stride = op.strides();
-    auto dilation = op.dilations();
-
-    return matchAndReplaceDepthwiseConv(operation, input, kernel, iZp, kZp,
-                                        init, stride, dilation, rewriter);
-  }
-};
-
 } // namespace
 
 #define LINALGOP_FOLDERS(XXX)                                                  \
@@ -2798,8 +2686,7 @@ LINALGOP_FOLDERS(GenericOp)
 
 void LinalgDialect::getCanonicalizationPatterns(
     RewritePatternSet &results) const {
-  results.add<FoldTensorCastOp, SimplifyDepthwiseConvOp,
-              SimplifyDepthwiseConvQOp>(getContext());
+  results.add<FoldTensorCastOp>(getContext());
 }
 
 Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 5edede26bb72..5df61c73fcc6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   Interchange.cpp
   Loops.cpp
   LinalgStrategyPasses.cpp
+  NamedOpConversions.cpp
   Promotion.cpp
   Tiling.cpp
   Transforms.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
new file mode 100644
index 000000000000..bb38607d769a
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
@@ -0,0 +1,160 @@
+//===- NamedOpConversions.cpp - Implements conversions between named ops --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements conversions between named ops that can be seen as
+// canonicalizations of named ops.
+//
+//===----------------------------------------------------------------------===//
+#include "PassDetail.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "llvm/ADT/SmallVector.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+static llvm::SmallVector<int64_t, 2> getIndicesVector(int start, int end) {
+  return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));
+}
+
+static LogicalResult
+matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel,
+                             Value iZp, Value kZp, Value init, Attribute stride,
+                             Attribute dilation, PatternRewriter &rewriter) {
+  Location loc = operation->getLoc();
+  auto linalgOp = dyn_cast<LinalgOp>(operation);
+  // Exit out on the memref version of this operation.
+  if (!linalgOp || !linalgOp.hasTensorSemantics())
+    return failure();
+
+  auto result = operation->getResult(0);
+
+  auto kernelTy = kernel.getType().dyn_cast<RankedTensorType>();
+  auto initTy = init.getType().dyn_cast<RankedTensorType>();
+  auto resultTy = result.getType().template dyn_cast<RankedTensorType>();
+  if (!kernelTy || !initTy || !resultTy)
+    return failure();
+
+  if (kernelTy.getDimSize(3) != 1)
+    return failure();
+
+  // Collapse kernel dims.
+  SmallVector<ReassociationIndices, 4> collapsedKernelDims = {
+      getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 4)};
+  auto newKernelTy = RankedTensorType::get(
+      {kernelTy.getDimSize(0), kernelTy.getDimSize(1), kernelTy.getDimSize(2)},
+      kernelTy.getElementType());
+  auto collapsedKernel = rewriter.create<tensor::CollapseShapeOp>(
+      loc, newKernelTy, kernel, collapsedKernelDims);
+
+  // Collapse init dims.
+  SmallVector<ReassociationIndices, 4> collapsedInitDims = {
+      getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 3),
+      getIndicesVector(3, 5)};
+  auto newInitTy =
+      RankedTensorType::get({initTy.getDimSize(0), initTy.getDimSize(1),
+                             initTy.getDimSize(2), initTy.getDimSize(3)},
+                            initTy.getElementType());
+  auto collapsedInit = rewriter.create<tensor::CollapseShapeOp>(
+      loc, newInitTy, init, collapsedInitDims);
+
+  Value newConv;
+  if (isa<DepthwiseConv2DNhwcHwcmOp>(operation)) {
+    newConv = rewriter
+                  .create<DepthwiseConv2DNhwcHwcOp>(
+                      loc, newInitTy, ValueRange{input, collapsedKernel},
+                      ValueRange{collapsedInit}, stride, dilation)
+                  .getResult(0);
+  } else if (isa<DepthwiseConv2DNhwcHwcmQOp>(operation)) {
+    newConv =
+        rewriter
+            .create<DepthwiseConv2DNhwcHwcQOp>(
+                loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp},
+                ValueRange{collapsedInit}, stride, dilation)
+            .getResult(0);
+  }
+
+  if (!newConv)
+    return failure();
+
+  // Expand dimensions back out to match the original output shape.
+  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
+      operation, resultTy, newConv, collapsedInitDims);
+  return success();
+}
+
+namespace {
+struct SimplifyDepthwiseConvOp
+    : public OpRewritePattern<DepthwiseConv2DNhwcHwcmOp> {
+  using OpRewritePattern<DepthwiseConv2DNhwcHwcmOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmOp op,
+                                PatternRewriter &rewriter) const override {
+    Operation *operation = op.getOperation();
+    Value input = op.getInputOperand(0)->get();
+    Value kernel = op.getInputOperand(1)->get();
+    Value init = op.getOutputOperand(0)->get();
+
+    auto stride = op.strides();
+    auto dilation = op.dilations();
+
+    return matchAndReplaceDepthwiseConv(operation, input, kernel, nullptr,
+                                        nullptr, init, stride, dilation,
+                                        rewriter);
+  }
+};
+
+struct SimplifyDepthwiseConvQOp
+    : public OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp> {
+  using OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmQOp op,
+                                PatternRewriter &rewriter) const override {
+    Operation *operation = op.getOperation();
+    Value input = op.getInputOperand(0)->get();
+    Value kernel = op.getInputOperand(1)->get();
+    Value iZp = op.getInputOperand(2)->get();
+    Value kZp = op.getInputOperand(3)->get();
+    Value init = op.getOutputOperand(0)->get();
+
+    auto stride = op.strides();
+    auto dilation = op.dilations();
+
+    return matchAndReplaceDepthwiseConv(operation, input, kernel, iZp, kZp,
+                                        init, stride, dilation, rewriter);
+  }
+};
+
+struct LinalgNamedOpConversionPass
+    : public LinalgNamedOpConversionBase<LinalgNamedOpConversionPass> {
+  LinalgNamedOpConversionPass() = default;
+  LinalgNamedOpConversionPass(const LinalgNamedOpConversionPass &) {}
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    RewritePatternSet patterns(op->getContext());
+    populateLinalgNamedOpConversionPatterns(patterns);
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+void mlir::linalg::populateLinalgNamedOpConversionPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<SimplifyDepthwiseConvOp, SimplifyDepthwiseConvQOp>(
+      patterns.getContext());
+}
+
+std::unique_ptr<Pass> mlir::createLinalgNamedOpConversionPass() {
+  return std::make_unique<LinalgNamedOpConversionPass>();
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h b/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h
index b499dbbf0322..78cb590f0697 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h
+++ b/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h
@@ -38,6 +38,10 @@ namespace memref {
 class MemRefDialect;
 } // namespace memref
 
+namespace tensor {
+class TensorDialect;
+} // namespace tensor
+
 namespace vector {
 class VectorDialect;
 } // namespace vector
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 5465fee05f98..a6913d6f06e2 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -758,28 +758,3 @@ func @dim_of_tiled_loop_result_no_canonicalize(%arg0: tensor<?x?xf32>, %arg1: te
   %r2 = tensor.dim %r, %c0 : tensor<?x?xf32>
   return %r2 : index
 }
-
-// -----
-
-// CHECK-LABEL: @depthwise_conv
-func @depthwise_conv(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x1xf32>, %arg2: tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32> {
-  // CHECK-DAG: %[[KERNEL:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]]
-  // CHECK-DAG: %[[INIT:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]]
-  // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]] : tensor<?x?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[INIT]] : tensor<?x?x?x?xf32>)
-  // CHECK: %[[OUT:.+]] = tensor.expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]]
-  %0 = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x1xf32>) outs(%arg2 : tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32>
-  return %0 : tensor<?x?x?x?x1xf32>
-}
-
-
-// -----
-
-// CHECK-LABEL: @depthwise_conv_q
-func @depthwise_conv_q(%arg0: tensor<?x?x?x?xi8>, %arg1: tensor<?x?x?x1xi8>, %arg2: tensor<?x?x?x?x1xi32>, %arg3 : i32, %arg4 : i32) -> tensor<?x?x?x?x1xi32> {
-  // CHECK-DAG: %[[KERNEL:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]]
-  // CHECK-DAG: %[[INIT:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]]
-  // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]], %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?xi8>, i32, i32) outs(%[[INIT]] : tensor<?x?x?x?xi32>)
-  // CHECK: %[[OUT:.+]] = tensor.expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]]
-  %0 = linalg.depthwise_conv_2d_nhwc_hwcm_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1, %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?x1xi8>, i32, i32) outs(%arg2 : tensor<?x?x?x?x1xi32>) -> tensor<?x?x?x?x1xi32>
-  return %0 : tensor<?x?x?x?x1xi32>
-}
diff --git a/mlir/test/Dialect/Linalg/namedop_conversion.mlir b/mlir/test/Dialect/Linalg/namedop_conversion.mlir
new file mode 100644
index 000000000000..5f33f650930e
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/namedop_conversion.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt %s -linalg-named-op-conversion -split-input-file | FileCheck %s
+
+// CHECK-LABEL: @depthwise_conv
+func @depthwise_conv(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x1xf32>, %arg2: tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32> {
+  // CHECK-DAG: %[[KERNEL:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]]
+  // CHECK-DAG: %[[INIT:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]]
+  // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]] : tensor<?x?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[INIT]] : tensor<?x?x?x?xf32>)
+  // CHECK: %[[OUT:.+]] = tensor.expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]]
+  %0 = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x1xf32>) outs(%arg2 : tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32>
+  return %0 : tensor<?x?x?x?x1xf32>
+}
+
+
+// -----
+
+// CHECK-LABEL: @depthwise_conv_q
+func @depthwise_conv_q(%arg0: tensor<?x?x?x?xi8>, %arg1: tensor<?x?x?x1xi8>, %arg2: tensor<?x?x?x?x1xi32>, %arg3 : i32, %arg4 : i32) -> tensor<?x?x?x?x1xi32> {
+  // CHECK-DAG: %[[KERNEL:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]]
+  // CHECK-DAG: %[[INIT:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]]
+  // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]], %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?xi8>, i32, i32) outs(%[[INIT]] : tensor<?x?x?x?xi32>)
+  // CHECK: %[[OUT:.+]] = tensor.expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]]
+  %0 = linalg.depthwise_conv_2d_nhwc_hwcm_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1, %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?x1xi8>, i32, i32) outs(%arg2 : tensor<?x?x?x?x1xi32>) -> tensor<?x?x?x?x1xi32>
+  return %0 : tensor<?x?x?x?x1xi32>
+}
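
Usage note (not part of the patch): pipelines that previously relied on the canonicalizer to perform these rewrites now need to schedule them explicitly. Below is a minimal sketch of how that might look, assuming the standard MLIR pass-manager and greedy-rewriter APIs of this revision; only populateLinalgNamedOpConversionPatterns, createLinalgNamedOpConversionPass, and the -linalg-named-op-conversion flag come from this patch, and the helper name applyNamedOpConversions is hypothetical.

    // Illustrative sketch: apply the named-op conversions deliberately at the
    // whole-tensor level, mirroring what the -linalg-named-op-conversion pass
    // does, but under the caller's control.
    #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
    #include "mlir/IR/PatternMatch.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    static mlir::LogicalResult applyNamedOpConversions(mlir::Operation *op) {
      mlir::RewritePatternSet patterns(op->getContext());
      mlir::linalg::populateLinalgNamedOpConversionPatterns(patterns);
      // The greedy driver rewrites depthwise_conv_2d_nhwc_hwcm(_q) ops whose
      // channel-multiplier dimension is 1 into their *_hwc(_q) equivalents.
      return mlir::applyPatternsAndFoldGreedily(op, std::move(patterns));
    }

    // Or schedule the dedicated pass in a pipeline:
    //   pm.addPass(mlir::createLinalgNamedOpConversionPass());
    // which corresponds to running: mlir-opt -linalg-named-op-conversion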