diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index 9beba34d08c5..ba6b75e41182 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -48,6 +48,7 @@ #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h" #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h" #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" +#include "mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h" #include "mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h" #include "mlir/Conversion/TosaToArith/TosaToArith.h" #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index d6b8b5dd357e..0c4aca31968d 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -708,6 +708,20 @@ def ConvertSPIRVToLLVM : Pass<"convert-spirv-to-llvm", "ModuleOp"> { let dependentDialects = ["LLVM::LLVMDialect"]; } +//===----------------------------------------------------------------------===// +// TensorToLinalg +//===----------------------------------------------------------------------===// + +def ConvertTensorToLinalg : Pass<"convert-tensor-to-linalg", "ModuleOp"> { + let summary = "Convert some Tensor dialect ops to Linalg dialect"; + let constructor = "mlir::createConvertTensorToLinalgPass()"; + let dependentDialects = [ + "arith::ArithmeticDialect", + "linalg::LinalgDialect", + ]; +} + + //===----------------------------------------------------------------------===// // TensorToSPIRV //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Conversion/TensorToLinalg/TensorToLinalg.h b/mlir/include/mlir/Conversion/TensorToLinalg/TensorToLinalg.h new file mode 100644 index 000000000000..dd5191f07896 --- /dev/null +++ b/mlir/include/mlir/Conversion/TensorToLinalg/TensorToLinalg.h @@ -0,0 +1,26 @@ +//===- TensorToLinalg.h - Tensor to Linalg Patterns -------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Provides patterns to convert Tensor dialect to Linalg dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_TENSORTOLINALG_TENSORTOLINALG_H +#define MLIR_CONVERSION_TENSORTOLINALG_TENSORTOLINALG_H + +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { + +/// Appends to a pattern list additional patterns for translating tensor ops +/// to Linalg ops. +void populateTensorToLinalgPatterns(RewritePatternSet &patterns); + +} // namespace mlir + +#endif // MLIR_CONVERSION_TENSORTOLINALG_TENSORTOLINALG_H diff --git a/mlir/include/mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h b/mlir/include/mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h new file mode 100644 index 000000000000..2f32179cd218 --- /dev/null +++ b/mlir/include/mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h @@ -0,0 +1,26 @@ +//===- TensorToLinalgPass.h - Tensor to Linalg Passes --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Provides passes to convert Tensor dialect to Linalg dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_TENSORTOLINALG_TENSORTOLINALGPASS_H +#define MLIR_CONVERSION_TENSORTOLINALG_TENSORTOLINALGPASS_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { +class ModuleOp; + +/// Creates a pass to convert Tensor ops to Linalg ops. +std::unique_ptr> createConvertTensorToLinalgPass(); + +} // namespace mlir + +#endif // MLIR_CONVERSION_TENSORTOLINALG_TENSORTOLINALGPASS_H diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index ccccd640378e..d0c0083dcf88 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -37,6 +37,7 @@ add_subdirectory(SCFToOpenMP) add_subdirectory(SCFToSPIRV) add_subdirectory(ShapeToStandard) add_subdirectory(SPIRVToLLVM) +add_subdirectory(TensorToLinalg) add_subdirectory(TensorToSPIRV) add_subdirectory(TosaToArith) add_subdirectory(TosaToLinalg) diff --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h index 5b630583e895..e05004061dd4 100644 --- a/mlir/lib/Conversion/PassDetail.h +++ b/mlir/lib/Conversion/PassDetail.h @@ -47,6 +47,10 @@ namespace func { class FuncDialect; } // namespace func +namespace linalg { +class LinalgDialect; +} // namespace linalg + namespace LLVM { class LLVMDialect; } // namespace LLVM diff --git a/mlir/lib/Conversion/TensorToLinalg/CMakeLists.txt b/mlir/lib/Conversion/TensorToLinalg/CMakeLists.txt new file mode 100644 index 000000000000..5bd7e9e4b8de --- /dev/null +++ b/mlir/lib/Conversion/TensorToLinalg/CMakeLists.txt @@ -0,0 +1,19 @@ +add_mlir_conversion_library(MLIRTensorToLinalg + TensorToLinalg.cpp + TensorToLinalgPass.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg + ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR + + DEPENDS + MLIRConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRLinalg + MLIRSupport + MLIRTransformUtils + MLIRTensor + ) diff --git a/mlir/lib/Conversion/TensorToLinalg/TensorToLinalg.cpp b/mlir/lib/Conversion/TensorToLinalg/TensorToLinalg.cpp new file mode 100644 index 000000000000..08d4dd7bb192 --- /dev/null +++ b/mlir/lib/Conversion/TensorToLinalg/TensorToLinalg.cpp @@ -0,0 +1,31 @@ +//===- TensorToLinalg.cpp - Tensor to Linalg Patterns ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements patterns to convert Tensor dialect to Linalg dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/TensorToLinalg/TensorToLinalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "tensor-to-linalg-pattern" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Pattern population +//===----------------------------------------------------------------------===// + +void mlir::populateTensorToLinalgPatterns(RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} diff --git a/mlir/lib/Conversion/TensorToLinalg/TensorToLinalgPass.cpp b/mlir/lib/Conversion/TensorToLinalg/TensorToLinalgPass.cpp new file mode 100644 index 000000000000..8be029e8c5c2 --- /dev/null +++ b/mlir/lib/Conversion/TensorToLinalg/TensorToLinalgPass.cpp @@ -0,0 +1,47 @@ +//===- TensorToLinalgPass.cpp - Tensor to Linalg Passes -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to convert Tensor dialect to Linalg dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h" +#include "../PassDetail.h" +#include "mlir/Conversion/TensorToLinalg/TensorToLinalg.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" + +using namespace mlir; + +namespace { +/// A pass converting MLIR Tensor operations into the Linalg dialect. +class ConvertTensorToLinalgPass + : public ConvertTensorToLinalgBase { + void runOnOperation() override { + auto &context = getContext(); + ConversionTarget target(context); + target.addLegalDialect(); + target.addIllegalOp(); + + RewritePatternSet patterns(&context); + populateTensorToLinalgPatterns(patterns); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); + } +}; +} // namespace + +std::unique_ptr> +mlir::createConvertTensorToLinalgPass() { + return std::make_unique(); +} diff --git a/mlir/test/Conversion/TensorToLinalg/tensor-ops-to-linalg.mlir b/mlir/test/Conversion/TensorToLinalg/tensor-ops-to-linalg.mlir new file mode 100644 index 000000000000..dcb5f1525e30 --- /dev/null +++ b/mlir/test/Conversion/TensorToLinalg/tensor-ops-to-linalg.mlir @@ -0,0 +1,47 @@ +// RUN: mlir-opt -split-input-file -convert-tensor-to-linalg -cse -verify-diagnostics %s | FileCheck %s + +//===----------------------------------------------------------------------===// +// tensor.pad +//===----------------------------------------------------------------------===// +// CHECK-LABEL: func @generalize_pad_tensor_static_shape( +// CHECK-SAME: %[[IN:.*]]: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> { +// CHECK: %[[C0:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[INIT:.*]] = linalg.init_tensor [1, 32, 32, 1] : tensor<1x32x32x1xf32> +// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[C0]] : f32) outs(%[[INIT]] : tensor<1x32x32x1xf32>) -> tensor<1x32x32x1xf32> +// CHECK: %[[PADDED:.*]] = tensor.insert_slice %[[IN]] into %[[FILL]][0, 2, 2, 0] [1, 28, 28, 1] [1, 1, 1, 1] : tensor<1x28x28x1xf32> into tensor<1x32x32x1xf32> +// CHECK: return %[[PADDED]] : tensor<1x32x32x1xf32> +func.func @generalize_pad_tensor_static_shape(%arg0: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.pad %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0] { + ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index): + tensor.yield %cst : f32 + } : tensor<1x28x28x1xf32> to tensor<1x32x32x1xf32> + return %0 : tensor<1x32x32x1xf32> +} + +// CHECK-LABEL: func @generalize_pad_tensor_dynamic_shape( +// CHECK-SAME: %[[IN:.*]]: tensor<4x?x2x?xf32>, +// CHECK-SAME: %[[OFFSET:.*]]: index) -> tensor<4x?x?x?xf32> { +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[DIM1:.*]] = tensor.dim %[[IN]], %[[C1]] : tensor<4x?x2x?xf32> +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[OUT_DIM2:.*]] = arith.addi %[[C2]], %[[OFFSET]] : index +// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index +// CHECK: %[[DIM3:.*]] = tensor.dim %[[IN]], %[[C3]] : tensor<4x?x2x?xf32> +// CHECK: %[[OUT_DIM3:.*]] = arith.addi %[[DIM3]], %[[OFFSET]] : index +// CHECK: %[[INIT:.*]] = linalg.init_tensor [4, %[[DIM1]], %[[OUT_DIM2]], %[[OUT_DIM3]]] : tensor<4x?x?x?xf32> +// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<4x?x?x?xf32>) -> tensor<4x?x?x?xf32> +// CHECK: %[[PADDED:.*]] = tensor.insert_slice %[[IN]] into %[[FILL]]{{\[}}%[[C0]], %[[C0]], %[[OFFSET]], %[[C0]]] [4, %[[DIM1]], 2, %[[DIM3]]] [1, 1, 1, 1] : tensor<4x?x2x?xf32> into tensor<4x?x?x?xf32> +// CHECK: return %[[PADDED]] : tensor<4x?x?x?xf32> +// CHECK: } +func.func @generalize_pad_tensor_dynamic_shape(%arg0: tensor<4x?x2x?xf32>, %arg1: index) -> tensor<4x?x?x?xf32> { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.0 : f32 + %out = tensor.pad %arg0 low[%c0, %c0, %arg1, %c0] high[%c0, %c0, %c0, %arg1] { + ^bb0(%gen_arg1: index, %gen_arg2: index, %gen_arg3: index, %gen_arg4: index): + tensor.yield %cst : f32 + } : tensor<4x?x2x?xf32> to tensor<4x?x?x?xf32> + return %out : tensor<4x?x?x?xf32> +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index ea947228155d..9cf7c5dc776b 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -2487,6 +2487,7 @@ cc_library( ":SCFToStandard", ":SPIRVToLLVM", ":ShapeToStandard", + ":TensorToLinalg", ":TensorToSPIRV", ":TosaToArith", ":TosaToLinalg", @@ -4677,6 +4678,36 @@ cc_library( ], ) +cc_library( + name = "TensorToLinalg", + srcs = glob([ + "lib/Conversion/TensorToLinalg/*.cpp", + "lib/Conversion/TensorToLinalg/*.h", + ]) + [":ConversionPassDetail"], + hdrs = glob([ + "include/mlir/Conversion/TensorToLinalg/*.h", + ]), + includes = [ + "include", + "lib/Conversion/TensorToLinalg", + ], + deps = [ + ":ArithmeticDialect", + ":ConversionPassIncGen", + ":FuncDialect", + ":IR", + ":LinalgOps", + ":LinalgTransforms", + ":Pass", + ":Support", + ":TensorDialect", + ":Transforms", + ":VectorOps", + "//llvm:Support", + ], +) + + cc_library( name = "TensorToSPIRV", srcs = glob([