diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h index f21252800af8..8ebaaa8f8e4d 100644 --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -26,6 +26,8 @@ std::unique_ptr createLinalgFoldUnitExtentDimsPass(); std::unique_ptr createLinalgElementwiseOpFusionPass(); std::unique_ptr createFoldReshapeOpsByLinearizationPass(); +std::unique_ptr createLinalgNamedOpConversionPass(); + std::unique_ptr> createLinalgTilingPass( ArrayRef tileSizes = {}, linalg::LinalgTilingLoopType loopType = linalg::LinalgTilingLoopType::Loops, diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td index 504bc562148f..5bcc8cc6e33f 100644 --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -100,6 +100,12 @@ def LinalgFoldReshapeOpsByLinearization : let dependentDialects = ["AffineDialect", "memref::MemRefDialect"]; } +def LinalgNamedOpConversion: Pass<"linalg-named-op-conversion"> { + let summary = "Convert from one named linalg op to another."; + let constructor = "mlir::createLinalgNamedOpConversionPass()"; + let dependentDialects = ["linalg::LinalgDialect", "tensor::TensorDialect"]; +} + def LinalgLowerTiledLoopsToSCF : FunctionPass<"convert-linalg-tiled-loops-to-scf"> { let summary = "Lower linalg tiled loops to SCF loops and parallel loops"; diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index c14259f7baba..34eef99dc729 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -86,6 +86,10 @@ void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns); void populateFoldUnitDimsReshapeOpsByLinearizationPatterns( RewritePatternSet &patterns); +/// Patterns to convert from one named op to another. These can be seen as +/// canonicalizations of named ops into another named op. +void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns); + /// Populates the given list with patterns to bufferize linalg ops. void populateLinalgBufferizePatterns( bufferization::BufferizeTypeConverter &converter, diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 6e8b08bcbb8e..26a0c9277b32 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -2665,118 +2665,6 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern { } }; -static llvm::SmallVector getIndicesVector(int start, int end) { - return llvm::to_vector<2>(llvm::seq(start, end)); -} - -LogicalResult matchAndReplaceDepthwiseConv(Operation *operation, Value input, - Value kernel, Value iZp, Value kZp, - Value init, Attribute stride, - Attribute dilation, - PatternRewriter &rewriter) { - Location loc = operation->getLoc(); - auto linalgOp = dyn_cast(operation); - // Exit out on the memref version of this operation. - if (!linalgOp || !linalgOp.hasTensorSemantics()) - return failure(); - - auto result = operation->getResult(0); - - auto kernelTy = kernel.getType().dyn_cast(); - auto initTy = init.getType().dyn_cast(); - auto resultTy = result.getType().template dyn_cast(); - if (!kernelTy || !initTy || !resultTy) - return failure(); - - if (kernelTy.getDimSize(3) != 1) - return failure(); - - // Collapse kernel dims. - SmallVector collapsedKernelDims = { - getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 4)}; - auto newKernelTy = RankedTensorType::get( - {kernelTy.getDimSize(0), kernelTy.getDimSize(1), kernelTy.getDimSize(2)}, - kernelTy.getElementType()); - auto collapsedKernel = rewriter.create( - loc, newKernelTy, kernel, collapsedKernelDims); - - // Collapse init dims. - SmallVector collapsedInitDims = { - getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 3), - getIndicesVector(3, 5)}; - auto newInitTy = - RankedTensorType::get({initTy.getDimSize(0), initTy.getDimSize(1), - initTy.getDimSize(2), initTy.getDimSize(3)}, - initTy.getElementType()); - auto collapsedInit = rewriter.create( - loc, newInitTy, init, collapsedInitDims); - - Value newConv; - if (isa(operation)) { - newConv = rewriter - .create( - loc, newInitTy, ValueRange{input, collapsedKernel}, - ValueRange{collapsedInit}, stride, dilation) - .getResult(0); - } else if (isa(operation)) { - newConv = - rewriter - .create( - loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp}, - ValueRange{collapsedInit}, stride, dilation) - .getResult(0); - } - - if (!newConv) - return failure(); - - // Expand dimensions back out to - rewriter.replaceOpWithNewOp( - operation, resultTy, newConv, collapsedInitDims); - return success(); -} - -struct SimplifyDepthwiseConvOp - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmOp op, - PatternRewriter &rewriter) const override { - Operation *operation = op.getOperation(); - Value input = op.getInputOperand(0)->get(); - Value kernel = op.getInputOperand(1)->get(); - Value init = op.getOutputOperand(0)->get(); - - auto stride = op.strides(); - auto dilation = op.dilations(); - - return matchAndReplaceDepthwiseConv(operation, input, kernel, nullptr, - nullptr, init, stride, dilation, - rewriter); - } -}; - -struct SimplifyDepthwiseConvQOp - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmQOp op, - PatternRewriter &rewriter) const override { - Operation *operation = op.getOperation(); - Value input = op.getInputOperand(0)->get(); - Value kernel = op.getInputOperand(1)->get(); - Value iZp = op.getInputOperand(2)->get(); - Value kZp = op.getInputOperand(3)->get(); - Value init = op.getOutputOperand(0)->get(); - - auto stride = op.strides(); - auto dilation = op.dilations(); - - return matchAndReplaceDepthwiseConv(operation, input, kernel, iZp, kZp, - init, stride, dilation, rewriter); - } -}; - } // namespace #define LINALGOP_FOLDERS(XXX) \ @@ -2798,8 +2686,7 @@ LINALGOP_FOLDERS(GenericOp) void LinalgDialect::getCanonicalizationPatterns( RewritePatternSet &results) const { - results.add(getContext()); + results.add(getContext()); } Operation *LinalgDialect::materializeConstant(OpBuilder &builder, diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt index 5edede26bb72..5df61c73fcc6 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms Interchange.cpp Loops.cpp LinalgStrategyPasses.cpp + NamedOpConversions.cpp Promotion.cpp Tiling.cpp Transforms.cpp diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp new file mode 100644 index 000000000000..bb38607d769a --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp @@ -0,0 +1,160 @@ +//===- NamedOpConversions.cpp - Implements conversions between named ops --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements conversions between named ops that can be seens as +// canonicalizations of named ops. +// +//===----------------------------------------------------------------------===// +#include "PassDetail.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "llvm/ADT/SmallVector.h" + +using namespace mlir; +using namespace mlir::linalg; + +static llvm::SmallVector getIndicesVector(int start, int end) { + return llvm::to_vector<2>(llvm::seq(start, end)); +} + +static LogicalResult +matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel, + Value iZp, Value kZp, Value init, Attribute stride, + Attribute dilation, PatternRewriter &rewriter) { + Location loc = operation->getLoc(); + auto linalgOp = dyn_cast(operation); + // Exit out on the memref version of this operation. + if (!linalgOp || !linalgOp.hasTensorSemantics()) + return failure(); + + auto result = operation->getResult(0); + + auto kernelTy = kernel.getType().dyn_cast(); + auto initTy = init.getType().dyn_cast(); + auto resultTy = result.getType().template dyn_cast(); + if (!kernelTy || !initTy || !resultTy) + return failure(); + + if (kernelTy.getDimSize(3) != 1) + return failure(); + + // Collapse kernel dims. + SmallVector collapsedKernelDims = { + getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 4)}; + auto newKernelTy = RankedTensorType::get( + {kernelTy.getDimSize(0), kernelTy.getDimSize(1), kernelTy.getDimSize(2)}, + kernelTy.getElementType()); + auto collapsedKernel = rewriter.create( + loc, newKernelTy, kernel, collapsedKernelDims); + + // Collapse init dims. + SmallVector collapsedInitDims = { + getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 3), + getIndicesVector(3, 5)}; + auto newInitTy = + RankedTensorType::get({initTy.getDimSize(0), initTy.getDimSize(1), + initTy.getDimSize(2), initTy.getDimSize(3)}, + initTy.getElementType()); + auto collapsedInit = rewriter.create( + loc, newInitTy, init, collapsedInitDims); + + Value newConv; + if (isa(operation)) { + newConv = rewriter + .create( + loc, newInitTy, ValueRange{input, collapsedKernel}, + ValueRange{collapsedInit}, stride, dilation) + .getResult(0); + } else if (isa(operation)) { + newConv = + rewriter + .create( + loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp}, + ValueRange{collapsedInit}, stride, dilation) + .getResult(0); + } + + if (!newConv) + return failure(); + + // Expand dimensions back out to + rewriter.replaceOpWithNewOp( + operation, resultTy, newConv, collapsedInitDims); + return success(); +} + +namespace { +struct SimplifyDepthwiseConvOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmOp op, + PatternRewriter &rewriter) const override { + Operation *operation = op.getOperation(); + Value input = op.getInputOperand(0)->get(); + Value kernel = op.getInputOperand(1)->get(); + Value init = op.getOutputOperand(0)->get(); + + auto stride = op.strides(); + auto dilation = op.dilations(); + + return matchAndReplaceDepthwiseConv(operation, input, kernel, nullptr, + nullptr, init, stride, dilation, + rewriter); + } +}; + +struct SimplifyDepthwiseConvQOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmQOp op, + PatternRewriter &rewriter) const override { + Operation *operation = op.getOperation(); + Value input = op.getInputOperand(0)->get(); + Value kernel = op.getInputOperand(1)->get(); + Value iZp = op.getInputOperand(2)->get(); + Value kZp = op.getInputOperand(3)->get(); + Value init = op.getOutputOperand(0)->get(); + + auto stride = op.strides(); + auto dilation = op.dilations(); + + return matchAndReplaceDepthwiseConv(operation, input, kernel, iZp, kZp, + init, stride, dilation, rewriter); + } +}; + +struct LinalgNamedOpConversionPass + : public LinalgNamedOpConversionBase { + LinalgNamedOpConversionPass() = default; + LinalgNamedOpConversionPass(const LinalgNamedOpConversionPass &) {} + + void runOnOperation() override { + Operation *op = getOperation(); + RewritePatternSet patterns(op->getContext()); + populateLinalgNamedOpConversionPatterns(patterns); + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) + return signalPassFailure(); + } +}; +} // namespace + +void mlir::linalg::populateLinalgNamedOpConversionPatterns( + RewritePatternSet &patterns) { + patterns.add( + patterns.getContext()); +} + +std::unique_ptr mlir::createLinalgNamedOpConversionPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h b/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h index b499dbbf0322..78cb590f0697 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h +++ b/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h @@ -38,6 +38,10 @@ namespace memref { class MemRefDialect; } // namespace memref +namespace tensor { +class TensorDialect; +} // namespace tensor + namespace vector { class VectorDialect; } // namespace vector diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index 5465fee05f98..a6913d6f06e2 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -758,28 +758,3 @@ func @dim_of_tiled_loop_result_no_canonicalize(%arg0: tensor, %arg1: te %r2 = tensor.dim %r, %c0 : tensor return %r2 : index } - -// ----- - -// CHECK-LABEL: @depthwise_conv -func @depthwise_conv(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - // CHECK-DAG: %[[KERNEL:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]] - // CHECK-DAG: %[[INIT:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]] - // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]] : tensor, tensor) outs(%[[INIT]] : tensor) - // CHECK: %[[OUT:.+]] = tensor.expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]] - %0 = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) -> tensor - return %0 : tensor -} - - -// ----- - -// CHECK-LABEL: @depthwise_conv_q -func @depthwise_conv_q(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3 : i32, %arg4 : i32) -> tensor { - // CHECK-DAG: %[[KERNEL:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]] - // CHECK-DAG: %[[INIT:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]] - // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]], %arg3, %arg4 : tensor, tensor, i32, i32) outs(%[[INIT]] : tensor) - // CHECK: %[[OUT:.+]] = tensor.expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]] - %0 = linalg.depthwise_conv_2d_nhwc_hwcm_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1, %arg3, %arg4 : tensor, tensor, i32, i32) outs(%arg2 : tensor) -> tensor - return %0 : tensor -} diff --git a/mlir/test/Dialect/Linalg/namedop_conversion.mlir b/mlir/test/Dialect/Linalg/namedop_conversion.mlir new file mode 100644 index 000000000000..5f33f650930e --- /dev/null +++ b/mlir/test/Dialect/Linalg/namedop_conversion.mlir @@ -0,0 +1,24 @@ +// RUN: mlir-opt %s -linalg-named-op-conversion -split-input-file | FileCheck %s + +// CHECK-LABEL: @depthwise_conv +func @depthwise_conv(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK-DAG: %[[KERNEL:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]] + // CHECK-DAG: %[[INIT:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]] + // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]] : tensor, tensor) outs(%[[INIT]] : tensor) + // CHECK: %[[OUT:.+]] = tensor.expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]] + %0 = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) -> tensor + return %0 : tensor +} + + +// ----- + +// CHECK-LABEL: @depthwise_conv_q +func @depthwise_conv_q(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3 : i32, %arg4 : i32) -> tensor { + // CHECK-DAG: %[[KERNEL:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]] + // CHECK-DAG: %[[INIT:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]] + // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]], %arg3, %arg4 : tensor, tensor, i32, i32) outs(%[[INIT]] : tensor) + // CHECK: %[[OUT:.+]] = tensor.expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]] + %0 = linalg.depthwise_conv_2d_nhwc_hwcm_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1, %arg3, %arg4 : tensor, tensor, i32, i32) outs(%arg2 : tensor) -> tensor + return %0 : tensor +}