[mlir] Add TensorToLinalgPass

This pass is to handle computationally complex operations like
tensor.pad which are not simply lowered to the exact same operation in
the memref dialect.

Differential Revision: https://reviews.llvm.org/D125384
This commit is contained in:
Tres Popp 2022-05-11 14:10:12 +02:00
parent 4de9a8ae3f
commit 1dce51b888
11 changed files with 247 additions and 0 deletions

View File

@ -48,6 +48,7 @@
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h"
#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h"
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
#include "mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h"
#include "mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h"
#include "mlir/Conversion/TosaToArith/TosaToArith.h"
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"

View File

@ -708,6 +708,20 @@ def ConvertSPIRVToLLVM : Pass<"convert-spirv-to-llvm", "ModuleOp"> {
let dependentDialects = ["LLVM::LLVMDialect"];
}
//===----------------------------------------------------------------------===//
// TensorToLinalg
//===----------------------------------------------------------------------===//
def ConvertTensorToLinalg : Pass<"convert-tensor-to-linalg", "ModuleOp"> {
let summary = "Convert some Tensor dialect ops to Linalg dialect";
let constructor = "mlir::createConvertTensorToLinalgPass()";
let dependentDialects = [
"arith::ArithmeticDialect",
"linalg::LinalgDialect",
];
}
//===----------------------------------------------------------------------===//
// TensorToSPIRV
//===----------------------------------------------------------------------===//

View File

@ -0,0 +1,26 @@
//===- TensorToLinalg.h - Tensor to Linalg Patterns -------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Provides patterns to convert Tensor dialect to Linalg dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_TENSORTOLINALG_TENSORTOLINALG_H
#define MLIR_CONVERSION_TENSORTOLINALG_TENSORTOLINALG_H
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
/// Appends to a pattern list additional patterns for translating tensor ops
/// to Linalg ops.
void populateTensorToLinalgPatterns(RewritePatternSet &patterns);
} // namespace mlir
#endif // MLIR_CONVERSION_TENSORTOLINALG_TENSORTOLINALG_H

View File

@ -0,0 +1,26 @@
//===- TensorToLinalgPass.h - Tensor to Linalg Passes --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Provides passes to convert Tensor dialect to Linalg dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_TENSORTOLINALG_TENSORTOLINALGPASS_H
#define MLIR_CONVERSION_TENSORTOLINALG_TENSORTOLINALGPASS_H
#include "mlir/Pass/Pass.h"
namespace mlir {
class ModuleOp;
/// Creates a pass to convert Tensor ops to Linalg ops.
std::unique_ptr<OperationPass<ModuleOp>> createConvertTensorToLinalgPass();
} // namespace mlir
#endif // MLIR_CONVERSION_TENSORTOLINALG_TENSORTOLINALGPASS_H

View File

@ -37,6 +37,7 @@ add_subdirectory(SCFToOpenMP)
add_subdirectory(SCFToSPIRV)
add_subdirectory(ShapeToStandard)
add_subdirectory(SPIRVToLLVM)
add_subdirectory(TensorToLinalg)
add_subdirectory(TensorToSPIRV)
add_subdirectory(TosaToArith)
add_subdirectory(TosaToLinalg)

View File

@ -47,6 +47,10 @@ namespace func {
class FuncDialect;
} // namespace func
namespace linalg {
class LinalgDialect;
} // namespace linalg
namespace LLVM {
class LLVMDialect;
} // namespace LLVM

View File

@ -0,0 +1,19 @@
add_mlir_conversion_library(MLIRTensorToLinalg
TensorToLinalg.cpp
TensorToLinalgPass.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
DEPENDS
MLIRConversionPassIncGen
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRLinalg
MLIRSupport
MLIRTransformUtils
MLIRTensor
)

View File

@ -0,0 +1,31 @@
//===- TensorToLinalg.cpp - Tensor to Linalg Patterns ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert Tensor dialect to Linalg dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/TensorToLinalg/TensorToLinalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "tensor-to-linalg-pattern"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//
void mlir::populateTensorToLinalgPatterns(RewritePatternSet &patterns) {
patterns.add<mlir::linalg::GeneralizePadOpPattern>(patterns.getContext());
}

View File

@ -0,0 +1,47 @@
//===- TensorToLinalgPass.cpp - Tensor to Linalg Passes -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert Tensor dialect to Linalg dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h"
#include "../PassDetail.h"
#include "mlir/Conversion/TensorToLinalg/TensorToLinalg.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
using namespace mlir;
namespace {
/// A pass converting MLIR Tensor operations into the Linalg dialect.
class ConvertTensorToLinalgPass
: public ConvertTensorToLinalgBase<ConvertTensorToLinalgPass> {
void runOnOperation() override {
auto &context = getContext();
ConversionTarget target(context);
target.addLegalDialect<mlir::arith::ArithmeticDialect,
mlir::linalg::LinalgDialect,
mlir::tensor::TensorDialect>();
target.addIllegalOp<mlir::tensor::PadOp>();
RewritePatternSet patterns(&context);
populateTensorToLinalgPatterns(patterns);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
}
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertTensorToLinalgPass() {
return std::make_unique<ConvertTensorToLinalgPass>();
}

View File

@ -0,0 +1,47 @@
// RUN: mlir-opt -split-input-file -convert-tensor-to-linalg -cse -verify-diagnostics %s | FileCheck %s
//===----------------------------------------------------------------------===//
// tensor.pad
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @generalize_pad_tensor_static_shape(
// CHECK-SAME: %[[IN:.*]]: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> {
// CHECK: %[[C0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[INIT:.*]] = linalg.init_tensor [1, 32, 32, 1] : tensor<1x32x32x1xf32>
// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[C0]] : f32) outs(%[[INIT]] : tensor<1x32x32x1xf32>) -> tensor<1x32x32x1xf32>
// CHECK: %[[PADDED:.*]] = tensor.insert_slice %[[IN]] into %[[FILL]][0, 2, 2, 0] [1, 28, 28, 1] [1, 1, 1, 1] : tensor<1x28x28x1xf32> into tensor<1x32x32x1xf32>
// CHECK: return %[[PADDED]] : tensor<1x32x32x1xf32>
func.func @generalize_pad_tensor_static_shape(%arg0: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.pad %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0] {
^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
tensor.yield %cst : f32
} : tensor<1x28x28x1xf32> to tensor<1x32x32x1xf32>
return %0 : tensor<1x32x32x1xf32>
}
// CHECK-LABEL: func @generalize_pad_tensor_dynamic_shape(
// CHECK-SAME: %[[IN:.*]]: tensor<4x?x2x?xf32>,
// CHECK-SAME: %[[OFFSET:.*]]: index) -> tensor<4x?x?x?xf32> {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[DIM1:.*]] = tensor.dim %[[IN]], %[[C1]] : tensor<4x?x2x?xf32>
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[OUT_DIM2:.*]] = arith.addi %[[C2]], %[[OFFSET]] : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK: %[[DIM3:.*]] = tensor.dim %[[IN]], %[[C3]] : tensor<4x?x2x?xf32>
// CHECK: %[[OUT_DIM3:.*]] = arith.addi %[[DIM3]], %[[OFFSET]] : index
// CHECK: %[[INIT:.*]] = linalg.init_tensor [4, %[[DIM1]], %[[OUT_DIM2]], %[[OUT_DIM3]]] : tensor<4x?x?x?xf32>
// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<4x?x?x?xf32>) -> tensor<4x?x?x?xf32>
// CHECK: %[[PADDED:.*]] = tensor.insert_slice %[[IN]] into %[[FILL]]{{\[}}%[[C0]], %[[C0]], %[[OFFSET]], %[[C0]]] [4, %[[DIM1]], 2, %[[DIM3]]] [1, 1, 1, 1] : tensor<4x?x2x?xf32> into tensor<4x?x?x?xf32>
// CHECK: return %[[PADDED]] : tensor<4x?x?x?xf32>
// CHECK: }
func.func @generalize_pad_tensor_dynamic_shape(%arg0: tensor<4x?x2x?xf32>, %arg1: index) -> tensor<4x?x?x?xf32> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.0 : f32
%out = tensor.pad %arg0 low[%c0, %c0, %arg1, %c0] high[%c0, %c0, %c0, %arg1] {
^bb0(%gen_arg1: index, %gen_arg2: index, %gen_arg3: index, %gen_arg4: index):
tensor.yield %cst : f32
} : tensor<4x?x2x?xf32> to tensor<4x?x?x?xf32>
return %out : tensor<4x?x?x?xf32>
}

View File

@ -2487,6 +2487,7 @@ cc_library(
":SCFToStandard",
":SPIRVToLLVM",
":ShapeToStandard",
":TensorToLinalg",
":TensorToSPIRV",
":TosaToArith",
":TosaToLinalg",
@ -4677,6 +4678,36 @@ cc_library(
],
)
cc_library(
name = "TensorToLinalg",
srcs = glob([
"lib/Conversion/TensorToLinalg/*.cpp",
"lib/Conversion/TensorToLinalg/*.h",
]) + [":ConversionPassDetail"],
hdrs = glob([
"include/mlir/Conversion/TensorToLinalg/*.h",
]),
includes = [
"include",
"lib/Conversion/TensorToLinalg",
],
deps = [
":ArithmeticDialect",
":ConversionPassIncGen",
":FuncDialect",
":IR",
":LinalgOps",
":LinalgTransforms",
":Pass",
":Support",
":TensorDialect",
":Transforms",
":VectorOps",
"//llvm:Support",
],
)
cc_library(
name = "TensorToSPIRV",
srcs = glob([