forked from OSchip/llvm-project
[mlir] Add TensorToLinalgPass
This pass is to handle computationally complex operations like tensor.pad which are not simply lowered to the exact same operation in the memref dialect. Differential Revision: https://reviews.llvm.org/D125384
This commit is contained in:
parent
4de9a8ae3f
commit
1dce51b888
|
@ -48,6 +48,7 @@
|
|||
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h"
|
||||
#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h"
|
||||
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
|
||||
#include "mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h"
|
||||
#include "mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h"
|
||||
#include "mlir/Conversion/TosaToArith/TosaToArith.h"
|
||||
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
|
||||
|
|
|
@ -708,6 +708,20 @@ def ConvertSPIRVToLLVM : Pass<"convert-spirv-to-llvm", "ModuleOp"> {
|
|||
let dependentDialects = ["LLVM::LLVMDialect"];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorToLinalg
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ConvertTensorToLinalg : Pass<"convert-tensor-to-linalg", "ModuleOp"> {
|
||||
let summary = "Convert some Tensor dialect ops to Linalg dialect";
|
||||
let constructor = "mlir::createConvertTensorToLinalgPass()";
|
||||
let dependentDialects = [
|
||||
"arith::ArithmeticDialect",
|
||||
"linalg::LinalgDialect",
|
||||
];
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorToSPIRV
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
//===- TensorToLinalg.h - Tensor to Linalg Patterns -------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Provides patterns to convert Tensor dialect to Linalg dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_CONVERSION_TENSORTOLINALG_TENSORTOLINALG_H
|
||||
#define MLIR_CONVERSION_TENSORTOLINALG_TENSORTOLINALG_H
|
||||
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
/// Appends to a pattern list additional patterns for translating tensor ops
|
||||
/// to Linalg ops.
|
||||
void populateTensorToLinalgPatterns(RewritePatternSet &patterns);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_CONVERSION_TENSORTOLINALG_TENSORTOLINALG_H
|
|
@ -0,0 +1,26 @@
|
|||
//===- TensorToLinalgPass.h - Tensor to Linalg Passes --------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Provides passes to convert Tensor dialect to Linalg dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_CONVERSION_TENSORTOLINALG_TENSORTOLINALGPASS_H
|
||||
#define MLIR_CONVERSION_TENSORTOLINALG_TENSORTOLINALGPASS_H
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
class ModuleOp;
|
||||
|
||||
/// Creates a pass to convert Tensor ops to Linalg ops.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertTensorToLinalgPass();
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_CONVERSION_TENSORTOLINALG_TENSORTOLINALGPASS_H
|
|
@ -37,6 +37,7 @@ add_subdirectory(SCFToOpenMP)
|
|||
add_subdirectory(SCFToSPIRV)
|
||||
add_subdirectory(ShapeToStandard)
|
||||
add_subdirectory(SPIRVToLLVM)
|
||||
add_subdirectory(TensorToLinalg)
|
||||
add_subdirectory(TensorToSPIRV)
|
||||
add_subdirectory(TosaToArith)
|
||||
add_subdirectory(TosaToLinalg)
|
||||
|
|
|
@ -47,6 +47,10 @@ namespace func {
|
|||
class FuncDialect;
|
||||
} // namespace func
|
||||
|
||||
namespace linalg {
|
||||
class LinalgDialect;
|
||||
} // namespace linalg
|
||||
|
||||
namespace LLVM {
|
||||
class LLVMDialect;
|
||||
} // namespace LLVM
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
add_mlir_conversion_library(MLIRTensorToLinalg
|
||||
TensorToLinalg.cpp
|
||||
TensorToLinalgPass.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
|
||||
|
||||
DEPENDS
|
||||
MLIRConversionPassIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRPass
|
||||
MLIRLinalg
|
||||
MLIRSupport
|
||||
MLIRTransformUtils
|
||||
MLIRTensor
|
||||
)
|
|
@ -0,0 +1,31 @@
|
|||
//===- TensorToLinalg.cpp - Tensor to Linalg Patterns ---------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements patterns to convert Tensor dialect to Linalg dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Conversion/TensorToLinalg/TensorToLinalg.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#define DEBUG_TYPE "tensor-to-linalg-pattern"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pattern population
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void mlir::populateTensorToLinalgPatterns(RewritePatternSet &patterns) {
|
||||
patterns.add<mlir::linalg::GeneralizePadOpPattern>(patterns.getContext());
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
//===- TensorToLinalgPass.cpp - Tensor to Linalg Passes -------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements a pass to convert Tensor dialect to Linalg dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h"
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Conversion/TensorToLinalg/TensorToLinalg.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
/// A pass converting MLIR Tensor operations into the Linalg dialect.
|
||||
class ConvertTensorToLinalgPass
|
||||
: public ConvertTensorToLinalgBase<ConvertTensorToLinalgPass> {
|
||||
void runOnOperation() override {
|
||||
auto &context = getContext();
|
||||
ConversionTarget target(context);
|
||||
target.addLegalDialect<mlir::arith::ArithmeticDialect,
|
||||
mlir::linalg::LinalgDialect,
|
||||
mlir::tensor::TensorDialect>();
|
||||
target.addIllegalOp<mlir::tensor::PadOp>();
|
||||
|
||||
RewritePatternSet patterns(&context);
|
||||
populateTensorToLinalgPatterns(patterns);
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::createConvertTensorToLinalgPass() {
|
||||
return std::make_unique<ConvertTensorToLinalgPass>();
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
// RUN: mlir-opt -split-input-file -convert-tensor-to-linalg -cse -verify-diagnostics %s | FileCheck %s
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// tensor.pad
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CHECK-LABEL: func @generalize_pad_tensor_static_shape(
|
||||
// CHECK-SAME: %[[IN:.*]]: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> {
|
||||
// CHECK: %[[C0:.*]] = arith.constant 0.000000e+00 : f32
|
||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [1, 32, 32, 1] : tensor<1x32x32x1xf32>
|
||||
// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[C0]] : f32) outs(%[[INIT]] : tensor<1x32x32x1xf32>) -> tensor<1x32x32x1xf32>
|
||||
// CHECK: %[[PADDED:.*]] = tensor.insert_slice %[[IN]] into %[[FILL]][0, 2, 2, 0] [1, 28, 28, 1] [1, 1, 1, 1] : tensor<1x28x28x1xf32> into tensor<1x32x32x1xf32>
|
||||
// CHECK: return %[[PADDED]] : tensor<1x32x32x1xf32>
|
||||
func.func @generalize_pad_tensor_static_shape(%arg0: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> {
|
||||
%cst = arith.constant 0.000000e+00 : f32
|
||||
%0 = tensor.pad %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0] {
|
||||
^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
|
||||
tensor.yield %cst : f32
|
||||
} : tensor<1x28x28x1xf32> to tensor<1x32x32x1xf32>
|
||||
return %0 : tensor<1x32x32x1xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @generalize_pad_tensor_dynamic_shape(
|
||||
// CHECK-SAME: %[[IN:.*]]: tensor<4x?x2x?xf32>,
|
||||
// CHECK-SAME: %[[OFFSET:.*]]: index) -> tensor<4x?x?x?xf32> {
|
||||
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
|
||||
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
|
||||
// CHECK: %[[DIM1:.*]] = tensor.dim %[[IN]], %[[C1]] : tensor<4x?x2x?xf32>
|
||||
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
|
||||
// CHECK: %[[OUT_DIM2:.*]] = arith.addi %[[C2]], %[[OFFSET]] : index
|
||||
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
|
||||
// CHECK: %[[DIM3:.*]] = tensor.dim %[[IN]], %[[C3]] : tensor<4x?x2x?xf32>
|
||||
// CHECK: %[[OUT_DIM3:.*]] = arith.addi %[[DIM3]], %[[OFFSET]] : index
|
||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [4, %[[DIM1]], %[[OUT_DIM2]], %[[OUT_DIM3]]] : tensor<4x?x?x?xf32>
|
||||
// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<4x?x?x?xf32>) -> tensor<4x?x?x?xf32>
|
||||
// CHECK: %[[PADDED:.*]] = tensor.insert_slice %[[IN]] into %[[FILL]]{{\[}}%[[C0]], %[[C0]], %[[OFFSET]], %[[C0]]] [4, %[[DIM1]], 2, %[[DIM3]]] [1, 1, 1, 1] : tensor<4x?x2x?xf32> into tensor<4x?x?x?xf32>
|
||||
// CHECK: return %[[PADDED]] : tensor<4x?x?x?xf32>
|
||||
// CHECK: }
|
||||
func.func @generalize_pad_tensor_dynamic_shape(%arg0: tensor<4x?x2x?xf32>, %arg1: index) -> tensor<4x?x?x?xf32> {
|
||||
%c0 = arith.constant 0 : index
|
||||
%cst = arith.constant 0.0 : f32
|
||||
%out = tensor.pad %arg0 low[%c0, %c0, %arg1, %c0] high[%c0, %c0, %c0, %arg1] {
|
||||
^bb0(%gen_arg1: index, %gen_arg2: index, %gen_arg3: index, %gen_arg4: index):
|
||||
tensor.yield %cst : f32
|
||||
} : tensor<4x?x2x?xf32> to tensor<4x?x?x?xf32>
|
||||
return %out : tensor<4x?x?x?xf32>
|
||||
}
|
|
@ -2487,6 +2487,7 @@ cc_library(
|
|||
":SCFToStandard",
|
||||
":SPIRVToLLVM",
|
||||
":ShapeToStandard",
|
||||
":TensorToLinalg",
|
||||
":TensorToSPIRV",
|
||||
":TosaToArith",
|
||||
":TosaToLinalg",
|
||||
|
@ -4677,6 +4678,36 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "TensorToLinalg",
|
||||
srcs = glob([
|
||||
"lib/Conversion/TensorToLinalg/*.cpp",
|
||||
"lib/Conversion/TensorToLinalg/*.h",
|
||||
]) + [":ConversionPassDetail"],
|
||||
hdrs = glob([
|
||||
"include/mlir/Conversion/TensorToLinalg/*.h",
|
||||
]),
|
||||
includes = [
|
||||
"include",
|
||||
"lib/Conversion/TensorToLinalg",
|
||||
],
|
||||
deps = [
|
||||
":ArithmeticDialect",
|
||||
":ConversionPassIncGen",
|
||||
":FuncDialect",
|
||||
":IR",
|
||||
":LinalgOps",
|
||||
":LinalgTransforms",
|
||||
":Pass",
|
||||
":Support",
|
||||
":TensorDialect",
|
||||
":Transforms",
|
||||
":VectorOps",
|
||||
"//llvm:Support",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
cc_library(
|
||||
name = "TensorToSPIRV",
|
||||
srcs = glob([
|
||||
|
|
Loading…
Reference in New Issue