forked from OSchip/llvm-project
[mlir][shape] add outline-shape-computation pass
Add outline-shape-computation pass. This pass his pass outlines the shape computation part in high level IR by adding shape.func and populate corresponding mapping information into ShapeMappingAnalysis. Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D131810
This commit is contained in:
parent
8bcf22e318
commit
9f77909a5e
|
@ -0,0 +1,60 @@
|
|||
//===- ShapeMappingAnalysis.h - Preserve shape Info ------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_SHAPE_ANALYSIS_SHAPEMAPPINGANALYSIS_H_
|
||||
#define MLIR_DIALECT_SHAPE_ANALYSIS_SHAPEMAPPINGANALYSIS_H_
|
||||
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
namespace shape {
|
||||
|
||||
/// ShapeMappingValue works as the value of ShapeMappingAnalysis table, where
|
||||
/// `funcSymbol` is the symbol of mapping function, and `inputs` are the actual
|
||||
/// parameters for the function.
|
||||
struct ShapeMappingValue {
|
||||
ShapeMappingValue() = default;
|
||||
ShapeMappingValue(FlatSymbolRefAttr symbol, llvm::SmallVector<Value> &&inps)
|
||||
: funcSymbol(symbol), inputs(inps) {}
|
||||
|
||||
FlatSymbolRefAttr funcSymbol;
|
||||
llvm::SmallVector<Value> inputs;
|
||||
};
|
||||
|
||||
/// ShapeMappingAnalysis is used together with OutlineShapeComputationPass to
|
||||
/// preserve Value and corresponding shape function / arguments mapping
|
||||
/// information
|
||||
struct ShapeMappingAnalysis {
|
||||
ShapeMappingAnalysis(Operation *op) : operation(op) { (void)operation; }
|
||||
|
||||
/// Dumps the shape mapping information to the given stream.
|
||||
void print(raw_ostream &os) const {
|
||||
os << "// ---- Shape Mapping Information -----\n";
|
||||
for (const auto &it : shapeMapping) {
|
||||
const ShapeMappingValue &mappingValue = it.second;
|
||||
os << "// Shape for " << it.first << " :: " << mappingValue.funcSymbol;
|
||||
llvm::interleaveComma(mappingValue.inputs, os << "(");
|
||||
os << ")\n";
|
||||
}
|
||||
}
|
||||
|
||||
llvm::DenseMap<Value, ShapeMappingValue> shapeMapping;
|
||||
|
||||
private:
|
||||
Operation *operation;
|
||||
};
|
||||
|
||||
} // namespace shape
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_SHAPE_ANALYSIS_SHAPEMAPPINGANALYSIS_H_
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
namespace mlir {
|
||||
class ConversionTarget;
|
||||
class ModuleOp;
|
||||
class TypeConverter;
|
||||
namespace func {
|
||||
class FuncOp;
|
||||
|
@ -53,6 +54,10 @@ std::unique_ptr<OperationPass<func::FuncOp>> createRemoveShapeConstraintsPass();
|
|||
// level.
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createShapeBufferizePass();
|
||||
|
||||
/// Outline the shape computation part by adding shape.func and populate
|
||||
/// conrresponding mapping infomation into ShapeMappingAnalysis.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createOutlineShapeComputationPass();
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Registration
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -11,6 +11,88 @@
|
|||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def OutlineShapeComputation : Pass<"outline-shape-computation", "ModuleOp"> {
|
||||
let summary = "Using shape.func to preserve shape computation";
|
||||
let description = [{
|
||||
This pass outlines the shape computation part in high level IR by adding
|
||||
shape.func and populate corresponding mapping infoemation into
|
||||
ShapeMappingAnalysis. The shape computation part is usually introduced by
|
||||
shape reification, and each single dynamic shape is denoted by shape.with_shape.
|
||||
|
||||
There're two main reasons this shape-outline pass is needed:
|
||||
1. Many passes don't take shape reification part into consideration.
|
||||
Therefore we need to "remove" the shape reification part temporarily for
|
||||
these passes.
|
||||
2. Sometimes we cannot redo shape reification after converting from dialect
|
||||
A to dialect B. Because op-level shape reification is only implemented
|
||||
on A.
|
||||
|
||||
Input:
|
||||
|
||||
```mlir
|
||||
func.func @main(%arg0: tensor<?x4x?xf32>, %arg1: tensor<2x4x?xf32>) ->
|
||||
tensor<?x4x?xf32> {
|
||||
%c2 = arith.constant 2 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
%c4 = arith.constant 4 : index
|
||||
%0 = shape.shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
|
||||
%1 = shape.get_extent %0, %c2 : tensor<3xindex>, index -> index
|
||||
%2 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
|
||||
%3 = shape.with_shape %2, %0 : tensor<?x4x?xf32>, tensor<3xindex>
|
||||
%4 = shape.value_of %3 : tensor<?x4x?xf32>
|
||||
%5 = "test.concat"(%4, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>,
|
||||
tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
|
||||
%6 = shape.get_extent %0, %c0 : tensor<3xindex>, index -> index
|
||||
%7 = arith.addi %6, %c2 : index
|
||||
%8 = shape.from_extents %7, %c4, %1 : index, index, index
|
||||
%9 = shape.with_shape %5, %8 : tensor<?x4x?xf32>, !shape.shape
|
||||
%10 = shape.value_of %9 : tensor<?x4x?xf32>
|
||||
return %10 : tensor<?x4x?xf32>
|
||||
}
|
||||
```
|
||||
|
||||
Output
|
||||
```mlir
|
||||
func.func @main(%arg0: tensor<?x4x?xf32>, %arg1: tensor<2x4x?xf32>) ->
|
||||
tensor<?x4x?xf32> {
|
||||
%0 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
|
||||
%1 = "test.concat"(%0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>,
|
||||
tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
|
||||
return %1 : tensor<?x4x?xf32>
|
||||
}
|
||||
shape.func private @shape_cal_1(%arg0: tensor<?x4x?xf32>) -> !shape.shape {
|
||||
%c2 = arith.constant 2 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
%c4 = arith.constant 4 : index
|
||||
%0 = shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
|
||||
%1 = get_extent %0, %c2 : tensor<3xindex>, index -> index
|
||||
%2 = get_extent %0, %c0 : tensor<3xindex>, index -> index
|
||||
%3 = arith.addi %2, %c2 : index
|
||||
%4 = from_extents %3, %c4, %1 : index, index, index
|
||||
return %4 : !shape.shape
|
||||
}
|
||||
shape.func private @shape_cal_0(%arg0: tensor<?x4x?xf32>) -> tensor<3xindex> {
|
||||
%0 = shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
|
||||
return %0 : tensor<3xindex>
|
||||
}
|
||||
```
|
||||
|
||||
For the above example, the shape computation is inlined in the input IR,
|
||||
which is used for two values' (test.abs and test.concat) shape. And the shape
|
||||
compuatation part is outlined in the output IR.
|
||||
|
||||
And the shape mapping infomation will be:
|
||||
|
||||
```
|
||||
// ---- Shape Mapping Infomation -----
|
||||
// - Shape for: %0 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32> :: @shape_cal_0(<block argument> of type 'tensor<?x4x?xf32>' at index: 0)
|
||||
// - Shape for: %1 = "test.concat"(%0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32> :: @shape_cal_1(<block argument> of type 'tensor<?x4x?xf32>' at index: 0)
|
||||
```
|
||||
}];
|
||||
let constructor = "mlir::createOutlineShapeComputationPass()";
|
||||
let dependentDialects = ["shape::ShapeDialect"];
|
||||
}
|
||||
|
||||
def RemoveShapeConstraints : Pass<"remove-shape-constraints", "func::FuncOp"> {
|
||||
let summary = "Replace all cstr_ ops with a true witness";
|
||||
let constructor = "mlir::createRemoveShapeConstraintsPass()";
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
add_mlir_dialect_library(MLIRShapeOpsTransforms
|
||||
BufferizableOpInterfaceImpl.cpp
|
||||
Bufferize.cpp
|
||||
OutlineShapeComputation.cpp
|
||||
RemoveShapeConstraints.cpp
|
||||
ShapeToShapeLowering.cpp
|
||||
|
||||
|
|
|
@ -0,0 +1,318 @@
|
|||
//====----- OutlineShapeComputation.cpp -----------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/Shape/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include <queue>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
namespace mlir {
|
||||
#define GEN_PASS_DEF_OUTLINESHAPECOMPUTATION
|
||||
#include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
|
||||
} // namespace mlir
|
||||
|
||||
#define DEBUG_TYPE "outline-shape-computation"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
// A Value is an input of the cluster if it is an operand of an operation in the
|
||||
// cluster and its defining operation is not in the cluster.
|
||||
SmallVector<Value, 4>
|
||||
getInputsOfCluster(const llvm::SmallVector<Operation *, 8> &cluster) {
|
||||
SmallVector<Value, 4> inputs;
|
||||
llvm::SmallDenseSet<Value> inputSet;
|
||||
llvm::SmallDenseSet<Operation *> opSet;
|
||||
for (Operation *op : cluster) {
|
||||
bool inserted = opSet.insert(op).second;
|
||||
(void)inserted;
|
||||
assert(inserted && "cluster contains duplicate operations");
|
||||
}
|
||||
|
||||
for (Operation *op : cluster) {
|
||||
for (Value operand : op->getOperands()) {
|
||||
Operation *operandOp = operand.getDefiningOp();
|
||||
if (opSet.find(operandOp) != opSet.end()) {
|
||||
// Skip if defining op is in the cluster.
|
||||
continue;
|
||||
}
|
||||
if (inputSet.insert(operand).second)
|
||||
inputs.push_back(operand);
|
||||
}
|
||||
}
|
||||
return inputs;
|
||||
}
|
||||
|
||||
// Create a shape.func representing the shape computation for `shape`.
|
||||
std::pair<shape::FuncOp, SmallVector<Value>>
|
||||
createFuncFromCluster(OpBuilder &b, const SmallVector<Operation *, 8> &cluster,
|
||||
Value shape, StringRef fnName, Location loc) {
|
||||
SmallVector<Value, 4> inputs = getInputsOfCluster(cluster);
|
||||
auto fnType =
|
||||
cluster.empty()
|
||||
? b.getFunctionType(shape.getType(), shape.getType())
|
||||
: b.getFunctionType(ValueRange(inputs).getTypes(), shape.getType());
|
||||
shape::FuncOp fnOp = b.create<shape::FuncOp>(loc, fnName, fnType);
|
||||
Block *block = fnOp.addEntryBlock();
|
||||
b.setInsertionPoint(block, block->end());
|
||||
BlockAndValueMapping bvm;
|
||||
if (cluster.empty()) {
|
||||
bvm.map(shape, fnOp.getArgument(0));
|
||||
} else {
|
||||
for (auto inputAndArg : llvm::zip(inputs, fnOp.getArguments()))
|
||||
bvm.map(std::get<0>(inputAndArg), std::get<1>(inputAndArg));
|
||||
}
|
||||
|
||||
for (Operation *op : cluster)
|
||||
b.clone(*op, bvm);
|
||||
llvm::SmallVector<Value, 4> fnReturns;
|
||||
fnReturns.push_back(bvm.lookupOrDefault(shape));
|
||||
|
||||
b.create<shape::ReturnOp>(loc, fnReturns);
|
||||
fnOp.setPrivate();
|
||||
return std::make_pair(fnOp, inputs);
|
||||
}
|
||||
|
||||
// The operations in the cluster might be unsorted, which could be inconvenient
|
||||
// when creating shape.func op.
|
||||
DenseMap<Value, SmallVector<Operation *, 8>>
|
||||
getOrderedClusters(const DenseMap<Value, DenseSet<Operation *>> &clusters,
|
||||
func::FuncOp funcOp) {
|
||||
// Compute all clusters that each operation is in
|
||||
DenseMap<Operation *, SmallVector<Value>> op2Shapes;
|
||||
for (const auto &it : clusters) {
|
||||
Value shape = it.first;
|
||||
const DenseSet<Operation *> &cluster = it.second;
|
||||
for (Operation *cOp : cluster)
|
||||
op2Shapes[cOp].push_back(shape);
|
||||
}
|
||||
|
||||
// Iterate through all operations in order. Get all the clusters `cOp` belongs
|
||||
// to and construct the new ordered cluster as it traverses.
|
||||
DenseMap<Value, SmallVector<Operation *, 8>> orderedClusters;
|
||||
funcOp.walk([&](Operation *op) {
|
||||
auto it = op2Shapes.find(op);
|
||||
if (it != op2Shapes.end()) {
|
||||
Operation *cOp = it->first;
|
||||
for (Value shape : it->second)
|
||||
orderedClusters[shape].push_back(cOp);
|
||||
}
|
||||
});
|
||||
|
||||
return orderedClusters;
|
||||
}
|
||||
|
||||
void constructShapeFunc(
|
||||
const std::vector<shape::WithOp> &allWithOps, MLIRContext *context,
|
||||
DenseMap<Value, SmallVector<Operation *, 8>> &clusters,
|
||||
SymbolTable &symbolTable,
|
||||
DenseMap<Value, shape::ShapeMappingValue> &dynShape2ShapeFunc,
|
||||
func::FuncOp funcOp, shape::ShapeMappingAnalysis &shapeMappingAnalysis) {
|
||||
std::string shapeCalculationNamePrefix = "shape_cal_";
|
||||
int shapeCalculationNameIdx = 0;
|
||||
OpBuilder builder(context);
|
||||
|
||||
// Construct a shape function
|
||||
for (shape::WithOp withOp : allWithOps) {
|
||||
Value value = withOp.getOperand();
|
||||
Value shape = withOp.getShape();
|
||||
RankedTensorType rankedType = value.getType().dyn_cast<RankedTensorType>();
|
||||
if (rankedType == nullptr)
|
||||
continue;
|
||||
|
||||
const SmallVector<Operation *, 8> &cluster = clusters[shape];
|
||||
shape::ShapeMappingValue shapeMappingValue;
|
||||
auto it = dynShape2ShapeFunc.find(shape);
|
||||
if (it == dynShape2ShapeFunc.end()) {
|
||||
std::string name = shapeCalculationNamePrefix +
|
||||
std::to_string(shapeCalculationNameIdx++);
|
||||
Location loc = value.getLoc();
|
||||
builder.setInsertionPointAfter(funcOp);
|
||||
auto pair = createFuncFromCluster(builder, cluster, shape, name, loc);
|
||||
const SmallVector<Value> &inputs = pair.second;
|
||||
shape::FuncOp shapeFuncOp = pair.first;
|
||||
StringAttr insertedName = symbolTable.insert(shapeFuncOp);
|
||||
auto symbol = FlatSymbolRefAttr::get(context, insertedName);
|
||||
|
||||
shapeMappingValue.funcSymbol = symbol;
|
||||
shapeMappingValue.inputs = inputs;
|
||||
} else {
|
||||
shapeMappingValue = it->second;
|
||||
}
|
||||
dynShape2ShapeFunc[shape] = shapeMappingValue;
|
||||
shapeMappingAnalysis.shapeMapping.insert(
|
||||
std::make_pair(value, shapeMappingValue));
|
||||
}
|
||||
}
|
||||
|
||||
struct OutlineShapeComputationPass
|
||||
: public impl::OutlineShapeComputationBase<OutlineShapeComputationPass> {
|
||||
|
||||
void runOnOperation() override;
|
||||
|
||||
private:
|
||||
bool calOnlyUsedByWithShapesRecursively(Operation *op, Value prevOutput);
|
||||
|
||||
void getClusterFromValue(Value shape,
|
||||
DenseMap<Value, DenseSet<Operation *>> &clusters);
|
||||
|
||||
DenseMap<Value, SmallVector<Operation *, 8>>
|
||||
constructClustersForEachShape(const std::vector<shape::WithOp> &allWithOps,
|
||||
func::FuncOp funcOp);
|
||||
|
||||
DenseSet<Operation *> onlyUsedByWithShapes;
|
||||
};
|
||||
|
||||
class TensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
|
||||
using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(tensor::DimOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto shapeOf =
|
||||
rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.getSource());
|
||||
rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
|
||||
op.getIndex());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void OutlineShapeComputationPass::runOnOperation() {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
SymbolTable symbolTable(moduleOp);
|
||||
DenseMap<Value, shape::ShapeMappingValue> dynShape2ShapeFunc;
|
||||
auto &shapeMappingAnalysis = getAnalysis<shape::ShapeMappingAnalysis>();
|
||||
// TODO: This is as we populate this analysis during a pass that mutates. This
|
||||
// pass currently requires 1 single module being compiled.
|
||||
shapeMappingAnalysis.shapeMapping.clear();
|
||||
markAnalysesPreserved<shape::ShapeMappingAnalysis>();
|
||||
|
||||
moduleOp.walk([&](func::FuncOp funcOp) {
|
||||
MLIRContext *context = funcOp.getContext();
|
||||
RewritePatternSet prevPatterns(context);
|
||||
prevPatterns.insert<TensorDimOpRewriter>(context);
|
||||
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(prevPatterns))))
|
||||
return signalPassFailure();
|
||||
|
||||
// initialize class member `onlyUsedByWithShapes`
|
||||
onlyUsedByWithShapes.clear();
|
||||
funcOp.walk([&](Operation *op) {
|
||||
calOnlyUsedByWithShapesRecursively(op, /*prevOutput=*/nullptr);
|
||||
});
|
||||
LLVM_DEBUG({
|
||||
llvm::dbgs() << "onlyUsedByWithShapes table: \n";
|
||||
for (auto it : onlyUsedByWithShapes)
|
||||
llvm::dbgs() << *it << "\n";
|
||||
});
|
||||
|
||||
// collect all the shape.with_shape ops.
|
||||
std::vector<shape::WithOp> allWithOps;
|
||||
funcOp.walk([&](shape::WithOp withOp) { allWithOps.push_back(withOp); });
|
||||
|
||||
DenseMap<Value, SmallVector<Operation *, 8>> clusters =
|
||||
constructClustersForEachShape(allWithOps, funcOp);
|
||||
constructShapeFunc(allWithOps, context, clusters, symbolTable,
|
||||
dynShape2ShapeFunc, funcOp, shapeMappingAnalysis);
|
||||
|
||||
for (shape::WithOp withOp : allWithOps) {
|
||||
Value value = withOp.getOperand();
|
||||
for (Operation *user : withOp.getResult().getUsers()) {
|
||||
if (Value valueOf = llvm::dyn_cast<shape::ValueOfOp>(user))
|
||||
valueOf.replaceAllUsesExcept(value, withOp);
|
||||
}
|
||||
}
|
||||
|
||||
// Apply patterns, note this also performs DCE.
|
||||
if (failed(applyPatternsAndFoldGreedily(funcOp, {})))
|
||||
return signalPassFailure();
|
||||
});
|
||||
}
|
||||
|
||||
DenseMap<Value, SmallVector<Operation *, 8>>
|
||||
OutlineShapeComputationPass::constructClustersForEachShape(
|
||||
const std::vector<shape::WithOp> &allWithOps, func::FuncOp funcOp) {
|
||||
DenseMap<Value, DenseSet<Operation *>> clusters;
|
||||
for (shape::WithOp withOp : allWithOps) {
|
||||
Value shape = withOp.getShape();
|
||||
if (clusters.count(shape) == 0)
|
||||
getClusterFromValue(shape, clusters);
|
||||
}
|
||||
return getOrderedClusters(clusters, funcOp);
|
||||
}
|
||||
|
||||
// The output of a cluster is the `shape`, and the inputs are the outputs of
|
||||
// operations who are not in `onlyUsedByWithShapes`
|
||||
void OutlineShapeComputationPass::getClusterFromValue(
|
||||
Value shape, DenseMap<Value, DenseSet<Operation *>> &clusters) {
|
||||
DenseSet<Operation *> cluster;
|
||||
|
||||
DenseSet<Operation *> visited;
|
||||
std::queue<Operation *> queue;
|
||||
|
||||
// defOp == nullptr means shape is the argument of the func op
|
||||
if (Operation *defOp = shape.getDefiningOp()) {
|
||||
visited.insert(defOp);
|
||||
queue.push(defOp);
|
||||
}
|
||||
while (!queue.empty()) {
|
||||
Operation *op = queue.front();
|
||||
queue.pop();
|
||||
if (onlyUsedByWithShapes.contains(op)) {
|
||||
cluster.insert(op);
|
||||
for (Value inp : op->getOperands()) {
|
||||
Operation *inpDefOp = inp.getDefiningOp();
|
||||
if (nullptr != inpDefOp && !visited.contains(inpDefOp)) {
|
||||
visited.insert(inpDefOp);
|
||||
queue.push(inpDefOp);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
clusters[shape] = std::move(cluster);
|
||||
}
|
||||
|
||||
// Returns whether `op` is a shape.with_shape, or all the users' of `op`
|
||||
// eventually point to the shape operand of shape.with_shape ops
|
||||
bool OutlineShapeComputationPass::calOnlyUsedByWithShapesRecursively(
|
||||
Operation *op, Value prevOutput) {
|
||||
if (onlyUsedByWithShapes.contains(op))
|
||||
return true;
|
||||
|
||||
if (auto withOp = llvm::dyn_cast<shape::WithOp>(op))
|
||||
return withOp.getShape() == prevOutput;
|
||||
|
||||
if (op->use_empty())
|
||||
return false;
|
||||
|
||||
for (Value oup : op->getResults())
|
||||
for (Operation *user : oup.getUsers())
|
||||
if (!calOnlyUsedByWithShapesRecursively(user, oup))
|
||||
return false;
|
||||
|
||||
onlyUsedByWithShapes.insert(op);
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::createOutlineShapeComputationPass() {
|
||||
return std::make_unique<OutlineShapeComputationPass>();
|
||||
}
|
|
@ -0,0 +1,208 @@
|
|||
// RUN: mlir-opt -outline-shape-computation -test-print-shape-mapping -split-input-file %s 2>&1 | FileCheck %s
|
||||
|
||||
// Two dynamic shapes: one of direct shape.shape_of(arg) and the other.
|
||||
func.func @two_dynamic_one_direct_shape(%arg0: tensor<?x4x?xf32>, %arg1: tensor<2x4x?xf32>) -> tensor<?x4x?xf32> {
|
||||
// CHECK-DAG: Shape for {{.*}} = "test.abs"({{.*}}> :: @shape_cal_0(<block argument> of type 'tensor<?x4x?xf32>' at index: 0)
|
||||
// CHECK-DAG: Shape for {{.*}} = "test.concat"({{.*}}> :: @shape_cal_1(<block argument> of type 'tensor<?x4x?xf32>' at index: 0)
|
||||
%c2 = arith.constant 2 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
%c4 = arith.constant 4 : index
|
||||
%0 = shape.shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
|
||||
%1 = shape.get_extent %0, %c2 : tensor<3xindex>, index -> index
|
||||
%2 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
|
||||
%3 = shape.with_shape %2, %0 : tensor<?x4x?xf32>, tensor<3xindex>
|
||||
%4 = shape.value_of %3 : tensor<?x4x?xf32>
|
||||
%5 = "test.concat"(%4, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
|
||||
%6 = shape.get_extent %0, %c0 : tensor<3xindex>, index -> index
|
||||
%7 = arith.addi %6, %c2 : index
|
||||
%8 = shape.from_extents %7, %c4, %1 : index, index, index
|
||||
%9 = shape.with_shape %5, %8 : tensor<?x4x?xf32>, !shape.shape
|
||||
%10 = shape.value_of %9 : tensor<?x4x?xf32>
|
||||
return %10 : tensor<?x4x?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @two_dynamic_one_direct_shape
|
||||
// CHECK-NEXT: %0 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
|
||||
// CHECK-NEXT: %1 = "test.concat"(%0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
|
||||
// CHECK-NEXT: return %1 : tensor<?x4x?xf32>
|
||||
|
||||
// CHECK: shape.func private @shape_cal_1(%arg0: tensor<?x4x?xf32>) -> !shape.shape {
|
||||
// CHECK-DAG: %[[V0:.*]] = shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
|
||||
// CHECK-DAG: %[[V1:.*]] = get_extent %[[V0]], %c2 : tensor<3xindex>, index -> index
|
||||
// CHECK-DAG: %[[V2:.*]] = get_extent %[[V0]], %c0 : tensor<3xindex>, index -> index
|
||||
// CHECK-DAG: %[[V3:.*]] = arith.addi %[[V2]], %c2 : index
|
||||
// CHECK-DAG: %[[V4:.*]] = from_extents %[[V3]], %c4, %[[V1]] : index, index, index
|
||||
// CHECK-DAG: return %[[V4]] : !shape.shape
|
||||
|
||||
// CHECK: shape.func private @shape_cal_0(%arg0: tensor<?x4x?xf32>) -> tensor<3xindex> {
|
||||
// CHECK-DAG: %0 = shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
|
||||
// CHECK-DAG: return %0 : tensor<3xindex>
|
||||
|
||||
// -----
|
||||
|
||||
// Two dynamic shapes and they share the same shape.func
|
||||
func.func @two_dynamic_share_same_shape(%arg0: tensor<?x4x?xf32>, %arg1: tensor<2x4x?xf32>) -> tensor<?x4x?xf32> {
|
||||
%c2 = arith.constant 2 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
%c4 = arith.constant 4 : index
|
||||
%0 = shape.shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
|
||||
%1 = shape.get_extent %0, %c2 : tensor<3xindex>, index -> index
|
||||
%2 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
|
||||
%3 = shape.get_extent %0, %c0 : tensor<3xindex>, index -> index
|
||||
%4 = arith.addi %3, %c2 : index
|
||||
%5 = shape.from_extents %4, %c4, %1 : index, index, index
|
||||
%6 = shape.with_shape %2, %5 : tensor<?x4x?xf32>, !shape.shape
|
||||
%7 = shape.value_of %6 : tensor<?x4x?xf32>
|
||||
%8 = "test.abs"(%7) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
|
||||
%9 = shape.with_shape %8, %5 : tensor<?x4x?xf32>, !shape.shape
|
||||
%10 = shape.value_of %9 : tensor<?x4x?xf32>
|
||||
return %10 : tensor<?x4x?xf32>
|
||||
}
|
||||
// CHECK-LABEL: func.func @two_dynamic_share_same_shape
|
||||
// CHECK-NEXT: %0 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
|
||||
// CHECK-NEXT: %1 = "test.abs"(%0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
|
||||
// CHECK-NEXT: return %1 : tensor<?x4x?xf32>
|
||||
|
||||
// CHECK: shape.func private @shape_cal_0(%arg0: tensor<?x4x?xf32>) -> !shape.shape {
|
||||
// CHECK-DAG: %[[V0:.*]] = shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
|
||||
// CHECK-DAG: %[[V1:.*]] = get_extent %[[V0]], %c2 : tensor<3xindex>, index -> index
|
||||
// CHECK-DAG: %[[V2:.*]] = get_extent %[[V0]], %c0 : tensor<3xindex>, index -> index
|
||||
// CHECK-DAG: %[[V3:.*]] = arith.addi %[[V2]], %c2 : index
|
||||
// CHECK-DAG: %[[V4:.*]] = from_extents %[[V3]], %c4, %[[V1]] : index, index, index
|
||||
// CHECK-DAG: return %4 : !shape.shape
|
||||
// CHECK-NOT: shape_cal_1
|
||||
|
||||
// -----
|
||||
|
||||
// There's an internal dynamic shape source, and two other dynamic shapes shares it
|
||||
func.func @internal_dynamic_shape_source_shared(%arg0: tensor<?x4xf32>) -> tensor<?xi32> {
|
||||
%0 = "test.nonzero"(%arg0) : (tensor<?x4xf32>) -> tensor<?xi32>
|
||||
%1 = shape.shape_of %0 : tensor<?xi32> -> tensor<1xindex>
|
||||
%2 = shape.with_shape %0, %1 : tensor<?xi32>, tensor<1xindex>
|
||||
%3 = shape.value_of %2 : tensor<?xi32>
|
||||
%4 = "test.abs"(%3) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
%5 = shape.with_shape %4, %1 : tensor<?xi32>, tensor<1xindex>
|
||||
%6 = shape.value_of %5 : tensor<?xi32>
|
||||
%7 = "test.negate"(%6) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
%8 = shape.with_shape %7, %1 : tensor<?xi32>, tensor<1xindex>
|
||||
%9 = shape.value_of %8 : tensor<?xi32>
|
||||
return %9 : tensor<?xi32>
|
||||
}
|
||||
// CHECK-LABEL: func.func @internal_dynamic_shape_source_shared
|
||||
// CHECK-NEXT: %0 = "test.nonzero"(%arg0) : (tensor<?x4xf32>) -> tensor<?xi32>
|
||||
// CHECK-NEXT: %1 = "test.abs"(%0) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK-NEXT: %2 = "test.negate"(%1) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK-NEXT: return %2 : tensor<?xi32>
|
||||
|
||||
// CHECK: shape.func private @shape_cal_0(%arg0: tensor<?xi32>) -> tensor<1xindex> {
|
||||
// CHECK-NEXT: %0 = shape_of %arg0 : tensor<?xi32> -> tensor<1xindex>
|
||||
// CHECK-NEXT: return %0 : tensor<1xindex>
|
||||
// CHECK-NOT: shape_cal_1
|
||||
|
||||
// -----
|
||||
|
||||
// There's only a return op in the constructed shape.func
|
||||
func.func @only_return_of_constructed_shape(%arg0: tensor<?x4xf32>, %arg1: tensor<1xindex>) -> tensor<?xi32> {
|
||||
%0 = "test.nonzero"(%arg0) : (tensor<?x4xf32>) -> tensor<?xi32>
|
||||
%1 = shape.with_shape %0, %arg1 : tensor<?xi32>, tensor<1xindex>
|
||||
%2 = shape.value_of %1 : tensor<?xi32>
|
||||
return %2 : tensor<?xi32>
|
||||
}
|
||||
// CHECK-LABEL: func.func @only_return_of_constructed_shape(%arg0: tensor<?x4xf32>, %arg1: tensor<1xindex>) -> tensor<?xi32> {
|
||||
// CHECK-NEXT: %0 = "test.nonzero"(%arg0) : (tensor<?x4xf32>) -> tensor<?xi32>
|
||||
// CHECK-NEXT: return %0 : tensor<?xi32>
|
||||
|
||||
// CHECK: shape.func private @shape_cal_0(%arg0: tensor<1xindex>) -> tensor<1xindex> {
|
||||
// CHECK-NEXT: return %arg0 : tensor<1xindex>
|
||||
|
||||
// -----
|
||||
|
||||
// Shape computation part interleaves with general computation.
|
||||
func.func @interleaved_shape_computation(%arg0: tensor<?x4x5xf32>, %arg1: tensor<?x4x5xf32>, %arg2: tensor<?x4x5xf32>) -> (tensor<?x4x5xf32>, index) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c4 = arith.constant 4 : index
|
||||
%c5 = arith.constant 5 : index
|
||||
%0 = shape.shape_of %arg0 : tensor<?x4x5xf32> -> tensor<3xindex>
|
||||
%1 = shape.shape_of %arg1 : tensor<?x4x5xf32> -> tensor<3xindex>
|
||||
%2 = shape.shape_of %arg2 : tensor<?x4x5xf32> -> tensor<3xindex>
|
||||
%3 = "test.concat"(%arg0, %arg1, %arg2) {axis = 0 : i64} : (tensor<?x4x5xf32>, tensor<?x4x5xf32>, tensor<?x4x5xf32>) -> tensor<?x4x5xf32>
|
||||
%4 = shape.get_extent %0, %c0 : tensor<3xindex>, index -> index
|
||||
%5 = shape.get_extent %1, %c0 : tensor<3xindex>, index -> index
|
||||
%6 = shape.get_extent %2, %c0 : tensor<3xindex>, index -> index
|
||||
%7 = arith.addi %4, %5 : index
|
||||
%8 = arith.addi %7, %6 : index
|
||||
%9 = shape.from_extents %8, %c4, %c5 : index, index, index
|
||||
%10 = shape.with_shape %3, %9 : tensor<?x4x5xf32>, !shape.shape
|
||||
%11 = shape.value_of %10 : tensor<?x4x5xf32>
|
||||
return %11, %7 : tensor<?x4x5xf32>, index
|
||||
}
|
||||
// CHECK-LABEL: func.func @interleaved_shape_computation
|
||||
// CHECK-DAG: %[[V0:.*]] = shape.shape_of %arg0 : tensor<?x4x5xf32> -> tensor<3xindex>
|
||||
// CHECK-DAG: %[[V1:.*]] = shape.shape_of %arg1 : tensor<?x4x5xf32> -> tensor<3xindex>
|
||||
// CHECK-DAG: %[[V2:.*]] = "test.concat"(%arg0, %arg1, %arg2) {axis = 0 : i64} : (tensor<?x4x5xf32>, tensor<?x4x5xf32>, tensor<?x4x5xf32>) -> tensor<?x4x5xf32>
|
||||
// CHECK-DAG: %[[V3:.*]] = shape.get_extent %[[V0]], %c0 : tensor<3xindex>, index -> index
|
||||
// CHECK-DAG: %[[V4:.*]] = shape.get_extent %[[V1]], %c0 : tensor<3xindex>, index -> index
|
||||
// CHECK-DAG: %[[V5:.*]] = arith.addi %[[V3]], %[[V4]] : index
|
||||
// CHECK-DAG: return %[[V2]], %[[V5]] : tensor<?x4x5xf32>, index
|
||||
|
||||
// CHECK: shape.func private @shape_cal_0(%arg0: tensor<?x4x5xf32>, %arg1: index, %arg2: index) -> !shape.shape {
|
||||
// CHECK-DAG: %[[V0:.*]] = shape_of %arg0 : tensor<?x4x5xf32> -> tensor<3xindex>
|
||||
// CHECK-DAG: %[[V1:.*]] = get_extent %[[V0]], %arg1 : tensor<3xindex>, index -> index
|
||||
// CHECK-DAG: %[[V2:.*]] = arith.addi %arg2, %[[V1]] : index
|
||||
// CHECK-DAG: %[[V3:.*]] = from_extents %[[V2]], %c4, %c5 : index, index, index
|
||||
// CHECK-DAG: return %[[V3]] : !shape.shape
|
||||
|
||||
// -----
|
||||
|
||||
// There're multiple reused shape computations.
|
||||
func.func @multiple_reused(%arg0: tensor<?x4xf32>, %arg1: tensor<?x4xf32>) -> (tensor<?x4xf32>, tensor<?x4xf32>, tensor<?x4xf32>, tensor<?x4xf32>) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c4 = arith.constant 4 : index
|
||||
%0 = shape.shape_of %arg0 : tensor<?x4xf32> -> tensor<2xindex>
|
||||
%1 = shape.shape_of %arg1 : tensor<?x4xf32> -> tensor<2xindex>
|
||||
%2 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
|
||||
%3 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
|
||||
%4 = shape.get_extent %0, %c0 : tensor<2xindex>, index -> index
|
||||
%5 = shape.get_extent %1, %c0 : tensor<2xindex>, index -> index
|
||||
%6 = arith.addi %4, %5 : index
|
||||
%7 = shape.from_extents %6, %c4 : index, index
|
||||
%8 = shape.with_shape %2, %7 : tensor<?x4xf32>, !shape.shape
|
||||
%9 = shape.with_shape %3, %7 : tensor<?x4xf32>, !shape.shape
|
||||
%10 = shape.value_of %8 : tensor<?x4xf32>
|
||||
%11 = shape.value_of %9 : tensor<?x4xf32>
|
||||
%12 = "test.concat"(%arg0, %2) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
|
||||
%13 = "test.concat"(%arg0, %3) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
|
||||
%14 = arith.addi %6, %4 : index
|
||||
%15 = shape.from_extents %14, %c4 : index, index
|
||||
%16 = shape.with_shape %12, %15 : tensor<?x4xf32>, !shape.shape
|
||||
%17 = shape.with_shape %13, %15 : tensor<?x4xf32>, !shape.shape
|
||||
%18 = shape.value_of %16 : tensor<?x4xf32>
|
||||
%19 = shape.value_of %17 : tensor<?x4xf32>
|
||||
return %10, %11, %18, %19 : tensor<?x4xf32>, tensor<?x4xf32>, tensor<?x4xf32>, tensor<?x4xf32>
|
||||
}
|
||||
// CHECK-LABEL: func.func @multiple_reused
|
||||
// CHECK-DAG: %[[V0:.*]] = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
|
||||
// CHECK-DAG: %[[V1:.*]] = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
|
||||
// CHECK-DAG: %[[V2:.*]] = "test.concat"(%arg0, %[[V0]]) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
|
||||
// CHECK-DAG: %[[V3:.*]] = "test.concat"(%arg0, %[[V1]]) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
|
||||
// CHECK-DAG: return %[[V0]], %[[V1]], %[[V2]], %[[V3]] : tensor<?x4xf32>, tensor<?x4xf32>, tensor<?x4xf32>, tensor<?x4xf32>
|
||||
|
||||
// CHECK: shape.func private @shape_cal_1(%arg0: tensor<?x4xf32>, %arg1: tensor<?x4xf32>) -> !shape.shape {
|
||||
// CHECK-DAG: %[[V0:.*]] = shape_of %arg0 : tensor<?x4xf32> -> tensor<2xindex>
|
||||
// CHECK-DAG: %[[V1:.*]] = shape_of %arg1 : tensor<?x4xf32> -> tensor<2xindex>
|
||||
// CHECK-DAG: %[[V2:.*]] = get_extent %[[V0]], %c0 : tensor<2xindex>, index -> index
|
||||
// CHECK-DAG: %[[V3:.*]] = get_extent %[[V1]], %c0 : tensor<2xindex>, index -> index
|
||||
// CHECK-DAG: %[[V4:.*]] = arith.addi %[[V2]], %[[V3]] : index
|
||||
// CHECK-DAG: %[[V5:.*]] = arith.addi %[[V4]], %[[V2]] : index
|
||||
// CHECK-DAG: %[[V6:.*]] = from_extents %[[V5]], %c4 : index, index
|
||||
// CHECK-DAG: return %[[V6]] : !shape.shape
|
||||
|
||||
// CHECK: shape.func private @shape_cal_0(%arg0: tensor<?x4xf32>, %arg1: tensor<?x4xf32>) -> !shape.shape {
|
||||
// CHECK-DAG: %[[V0:.*]] = shape_of %arg0 : tensor<?x4xf32> -> tensor<2xindex>
|
||||
// CHECK-DAG: %[[V1:.*]] = shape_of %arg1 : tensor<?x4xf32> -> tensor<2xindex>
|
||||
// CHECK-DAG: %[[V2:.*]] = get_extent %[[V0]], %c0 : tensor<2xindex>, index -> index
|
||||
// CHECK-DAG: %[[V3:.*]] = get_extent %[[V1]], %c0 : tensor<2xindex>, index -> index
|
||||
// CHECK-DAG: %[[V4:.*]] = arith.addi %[[V2]], %[[V3]] : index
|
||||
// CHECK-DAG: %[[V5:.*]] = from_extents %[[V4]], %c4 : index, index
|
||||
// CHECK-DAG: return %[[V5]] : !shape.shape
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
# Exclude tests from libMLIR.so
|
||||
add_mlir_library(MLIRShapeTestPasses
|
||||
TestShapeFunctions.cpp
|
||||
TestShapeMappingAnalysis.cpp
|
||||
|
||||
EXCLUDE_FROM_LIBMLIR
|
||||
|
||||
|
@ -11,6 +12,7 @@ add_mlir_library(MLIRShapeTestPasses
|
|||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRPass
|
||||
MLIRShapeOpsTransforms
|
||||
MLIRShapeDialect
|
||||
MLIRSupport
|
||||
)
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
//===- TestShapeMappingInfo.cpp -------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
struct TestShapeMappingPass
|
||||
: public PassWrapper<TestShapeMappingPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestShapeMappingPass)
|
||||
|
||||
StringRef getArgument() const final { return "test-print-shape-mapping"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Print the contents of a constructed shape mapping information.";
|
||||
}
|
||||
void runOnOperation() override {
|
||||
llvm::Optional<std::reference_wrapper<shape::ShapeMappingAnalysis>>
|
||||
maybeAnalysis = getCachedAnalysis<shape::ShapeMappingAnalysis>();
|
||||
if (maybeAnalysis.has_value())
|
||||
maybeAnalysis.value().get().print(llvm::errs());
|
||||
else
|
||||
llvm::errs() << "No cached ShapeMappingAnalysis existed.";
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestShapeMappingPass() {
|
||||
PassRegistration<TestShapeMappingPass>();
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace mlir
|
|
@ -109,6 +109,7 @@ void registerTestPDLLPasses();
|
|||
void registerTestPreparationPassWithAllowedMemrefResults();
|
||||
void registerTestRecursiveTypesPass();
|
||||
void registerTestSCFUtilsPass();
|
||||
void registerTestShapeMappingPass();
|
||||
void registerTestSliceAnalysisPass();
|
||||
void registerTestTensorTransforms();
|
||||
void registerTestTilingInterface();
|
||||
|
@ -208,6 +209,7 @@ void registerTestPasses() {
|
|||
mlir::test::registerTestPDLLPasses();
|
||||
mlir::test::registerTestRecursiveTypesPass();
|
||||
mlir::test::registerTestSCFUtilsPass();
|
||||
mlir::test::registerTestShapeMappingPass();
|
||||
mlir::test::registerTestSliceAnalysisPass();
|
||||
mlir::test::registerTestTensorTransforms();
|
||||
mlir::test::registerTestTilingInterface();
|
||||
|
|
Loading…
Reference in New Issue