[mlir] Rename ShapeTypeConversion to ShapeBufferize

Once we have tensor_to_memref ops suitable for type materializations,
this pass can be split into a generic type conversion pattern.

Part of the refactor discussed in:
https://llvm.discourse.group/t/what-is-the-strategy-for-tensor-memref-conversion-bufferization/1938/17

Differential Revision: https://reviews.llvm.org/D89258
This commit is contained in:
Sean Silva 2020-10-12 12:23:45 -07:00
parent 9ca97cde85
commit 6b30fb7653
5 changed files with 17 additions and 22 deletions

View File

@@ -43,9 +43,12 @@ std::unique_ptr<FunctionPass> createRemoveShapeConstraintsPass();
void populateShapeTypeConversionPatterns(
MLIRContext *ctx, BufferAssignmentTypeConverter &converter,
OwningRewritePatternList &patterns);
// Collects a set of patterns to replace tensors as inputs and outputs to shape
// operations with buffers. This only modifies the shape operations.
std::unique_ptr<FunctionPass> createShapeTensorToMemrefPass();
// Bufferizes shape dialect ops.
//
// Note that most shape dialect ops must be converted to std before
// bufferization happens, as they are intended to be bufferized at the std
// level.
std::unique_ptr<FunctionPass> createShapeBufferizePass();
//===----------------------------------------------------------------------===//
// Registration

View File

@@ -22,8 +22,8 @@ def ShapeToShapeLowering : FunctionPass<"shape-to-shape-lowering"> {
}
// TODO: Generalize this to allow any type conversions desired.
def ShapeTensorToMemref : FunctionPass<"shape-tensor-to-memref"> {
let summary = "Replace tensors involving shape operations with memrefs";
let constructor = "mlir::createShapeTensorToMemrefPass()";
def ShapeBufferize : FunctionPass<"shape-bufferize"> {
let summary = "Bufferize the shape dialect.";
let constructor = "mlir::createShapeBufferizePass()";
}
#endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES

View File

@@ -1,16 +1,12 @@
//====----- ShapeTypeConversion.cpp - Shape Type Conversions ----*- C++-*--===//
//====----- Bufferize.cpp - Bufferization of shape ops ---------*- C++-*--===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines patterns to convert types of inputs and outputs to shape
// operations to be memrefs instead of tensors.
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
@@ -18,7 +14,6 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Bufferize.h"
using namespace mlir;
using namespace mlir::shape;
@@ -53,8 +48,7 @@ public:
}
};
struct ShapeTensorToMemrefPass
: public ShapeTensorToMemrefBase<ShapeTensorToMemrefPass> {
struct ShapeBufferizePass : public ShapeBufferizeBase<ShapeBufferizePass> {
void runOnFunction() override {
MLIRContext &ctx = getContext();
@@ -87,9 +81,9 @@ void mlir::populateShapeTypeConversionPatterns(
}
//===----------------------------------------------------------------------===//
// ShapeTensorToMemrefPass construction
// ShapeBufferizePass construction
//===----------------------------------------------------------------------===//
std::unique_ptr<FunctionPass> mlir::createShapeTensorToMemrefPass() {
return std::make_unique<ShapeTensorToMemrefPass>();
std::unique_ptr<FunctionPass> mlir::createShapeBufferizePass() {
return std::make_unique<ShapeBufferizePass>();
}

View File

@@ -1,6 +1,6 @@
add_mlir_dialect_library(MLIRShapeOpsTransforms
Bufferize.cpp
RemoveShapeConstraints.cpp
ShapeTypeConversion.cpp
ShapeToShapeLowering.cpp
ADDITIONAL_HEADER_DIRS

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt -split-input-file -shape-tensor-to-memref <%s | FileCheck %s
// RUN: mlir-opt -split-input-file -shape-bufferize <%s | FileCheck %s
// -----
// Check that shape.assuming returns a memref.
@@ -14,5 +14,3 @@ func @shape_assuming_returns_memref() {
"test.sink"(%1) : (tensor<2xf16>) -> ()
return
}