forked from OSchip/llvm-project
[mlir][shape] Split out structural type conversions for shape dialect.
A "structural" type conversion is one where the underlying ops are completely agnostic to the actual types involved and simply need to update their types. An example of this is shape.assuming -- the shape.assuming op and the corresponding shape.assuming_yield op need to update their types accordingly to the TypeConverter, but otherwise don't care what type conversions are happening. Also, the previous conversion code would not correctly materialize conversions for the shape.assuming_yield op. This should have caused a verification failure, but shape.assuming's verifier wasn't calling RegionBranchOpInterface::verifyTypes (which for reasons can't be called automatically as part of the trait verification, and requires being called manually). This patch also adds that verification. Differential Revision: https://reviews.llvm.org/D89833
This commit is contained in:
parent
f0292ede9b
commit
57b338c08a
|
@ -635,6 +635,7 @@ def Shape_AssumingOp : Shape_Op<"assuming",
|
|||
|
||||
let printer = [{ return ::print(p, *this); }];
|
||||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||
let verifier = [{ return RegionBranchOpInterface::verifyTypes(*this); }];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Inline the region into the region containing the AssumingOp and delete
|
||||
|
|
|
@ -17,7 +17,8 @@
|
|||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
class BufferizeTypeConverter;
|
||||
class ConversionTarget;
|
||||
class TypeConverter;
|
||||
} // namespace mlir
|
||||
|
||||
namespace mlir {
|
||||
|
@ -40,9 +41,21 @@ void populateRemoveShapeConstraintsPatterns(OwningRewritePatternList &patterns,
|
|||
MLIRContext *ctx);
|
||||
std::unique_ptr<FunctionPass> createRemoveShapeConstraintsPass();
|
||||
|
||||
void populateShapeTypeConversionPatterns(MLIRContext *ctx,
|
||||
BufferizeTypeConverter &converter,
|
||||
OwningRewritePatternList &patterns);
|
||||
/// Populates patterns for shape dialect structural type conversions and sets up
|
||||
/// the provided ConversionTarget with the appropriate legality configuration
|
||||
/// for the ops to get converted properly.
|
||||
///
|
||||
/// A "structural" type conversion is one where the underlying ops are
|
||||
/// completely agnostic to the actual types involved and simply need to update
|
||||
/// their types consistently. An example of this is shape.assuming -- the
|
||||
/// shape.assuming op and the corresponding shape.assuming_yield op need to have
|
||||
/// consistent types, but the exact types don't matter. So all that we need to
|
||||
/// do for a structural type conversion is to update both of their types
|
||||
/// consistently to the new types prescribed by the TypeConverter.
|
||||
void populateShapeStructuralTypeConversionsAndLegality(
|
||||
MLIRContext *context, TypeConverter &typeConverter,
|
||||
OwningRewritePatternList &patterns, ConversionTarget &target);
|
||||
|
||||
// Bufferizes shape dialect ops.
|
||||
//
|
||||
// Note that most shape dialect ops must be converted to std before
|
||||
|
|
|
@ -8,82 +8,30 @@
|
|||
|
||||
#include "mlir/Transforms/Bufferize.h"
|
||||
#include "PassDetail.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/Shape/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::shape;
|
||||
|
||||
namespace {
|
||||
// Propagate tensor to memref conversions through shape.assuming ops.
|
||||
class TypeConversionAssumingOpConverter
|
||||
: public BufferizeOpConversionPattern<shape::AssumingOp> {
|
||||
public:
|
||||
using BufferizeOpConversionPattern<
|
||||
shape::AssumingOp>::BufferizeOpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(shape::AssumingOp assumingOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
SmallVector<Type, 2> newResultTypes;
|
||||
newResultTypes.reserve(assumingOp.getNumResults());
|
||||
for (auto result : assumingOp.getResults()) {
|
||||
auto originalType = result.getType();
|
||||
Type convertedType = converter.convertType(originalType);
|
||||
newResultTypes.push_back(convertedType);
|
||||
}
|
||||
|
||||
auto newAssumingOp = rewriter.create<shape::AssumingOp>(
|
||||
assumingOp.getLoc(), newResultTypes, assumingOp.witness());
|
||||
|
||||
rewriter.replaceOp(assumingOp, newAssumingOp.getResults());
|
||||
rewriter.inlineRegionBefore(assumingOp.doRegion(), newAssumingOp.doRegion(),
|
||||
newAssumingOp.doRegion().end());
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ShapeBufferizePass : public ShapeBufferizeBase<ShapeBufferizePass> {
|
||||
void runOnFunction() override {
|
||||
MLIRContext &ctx = getContext();
|
||||
|
||||
OwningRewritePatternList patterns;
|
||||
BufferizeTypeConverter converter;
|
||||
populateShapeTypeConversionPatterns(&ctx, converter, patterns);
|
||||
|
||||
BufferizeTypeConverter typeConverter;
|
||||
ConversionTarget target(getContext());
|
||||
auto isMemRefType = [](Type type) { return type.isa<BaseMemRefType>(); };
|
||||
|
||||
target.addDynamicallyLegalOp<AssumingOp>([&](shape::AssumingOp op) {
|
||||
return std::all_of(op.result_type_begin(), op.result_type_end(),
|
||||
isMemRefType);
|
||||
});
|
||||
populateBufferizeMaterializationLegality(target);
|
||||
populateShapeStructuralTypeConversionsAndLegality(&ctx, typeConverter,
|
||||
patterns, target);
|
||||
|
||||
if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
|
||||
if (failed(applyPartialConversion(getFunction(), target, patterns)))
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Populates `patterns` with the conversion patterns of tensor->memref.
|
||||
//
|
||||
// TODO: Change this to work generally with any type conversions.
|
||||
void mlir::populateShapeTypeConversionPatterns(
|
||||
MLIRContext *context, BufferizeTypeConverter &converter,
|
||||
OwningRewritePatternList &patterns) {
|
||||
patterns.insert<TypeConversionAssumingOpConverter>(context, converter);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ShapeBufferizePass construction
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
std::unique_ptr<FunctionPass> mlir::createShapeBufferizePass() {
|
||||
return std::make_unique<ShapeBufferizePass>();
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRShapeOpsTransforms
|
|||
Bufferize.cpp
|
||||
RemoveShapeConstraints.cpp
|
||||
ShapeToShapeLowering.cpp
|
||||
StructuralTypeConversions.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ShapeOps/Transforms
|
||||
|
|
|
@ -0,0 +1,71 @@
|
|||
//===- StructuralTypeConversions.cpp - Shape structural type conversions --===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PassDetail.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/Shape/Transforms/Passes.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::shape;
|
||||
|
||||
namespace {
|
||||
class ConvertAssumingOpTypes : public OpConversionPattern<AssumingOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(AssumingOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
SmallVector<Type, 2> newResultTypes;
|
||||
newResultTypes.reserve(op.getNumResults());
|
||||
for (auto result : op.getResults()) {
|
||||
auto originalType = result.getType();
|
||||
Type convertedType = getTypeConverter()->convertType(originalType);
|
||||
newResultTypes.push_back(convertedType);
|
||||
}
|
||||
|
||||
auto newAssumingOp =
|
||||
rewriter.create<AssumingOp>(op.getLoc(), newResultTypes, op.witness());
|
||||
|
||||
rewriter.replaceOp(op, newAssumingOp.getResults());
|
||||
rewriter.inlineRegionBefore(op.doRegion(), newAssumingOp.doRegion(),
|
||||
newAssumingOp.doRegion().end());
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAssumingYieldOpTypes
|
||||
: public OpConversionPattern<AssumingYieldOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(AssumingYieldOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
rewriter.replaceOpWithNewOp<AssumingYieldOp>(op, operands);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::populateShapeStructuralTypeConversionsAndLegality(
|
||||
MLIRContext *context, TypeConverter &typeConverter,
|
||||
OwningRewritePatternList &patterns, ConversionTarget &target) {
|
||||
patterns.insert<ConvertAssumingOpTypes, ConvertAssumingYieldOpTypes>(
|
||||
typeConverter, context);
|
||||
target.addDynamicallyLegalOp<AssumingOp>([&](AssumingOp op) {
|
||||
return typeConverter.isLegal(op.getResultTypes());
|
||||
});
|
||||
target.addDynamicallyLegalOp<AssumingYieldOp>([&](AssumingYieldOp op) {
|
||||
return typeConverter.isLegal(op.getOperandTypes());
|
||||
});
|
||||
}
|
|
@ -1,12 +1,20 @@
|
|||
// RUN: mlir-opt -split-input-file -shape-bufferize <%s | FileCheck %s
|
||||
|
||||
// -----
|
||||
// Check that shape.assuming returns a memref.
|
||||
//
|
||||
// CHECK-LABEL: @shape_assuming_returns_memref
|
||||
func @shape_assuming_returns_memref() {
|
||||
|
||||
// CHECK-LABEL: func @shape_assuming() {
|
||||
// CHECK: %[[WTRUE:.*]] = shape.const_witness true
|
||||
// CHECK: %[[MEMREF:.*]] = shape.assuming %[[WTRUE]] -> (memref<2xf16>) {
|
||||
// CHECK: %[[TENSOR_VAL:.*]] = "test.source"() : () -> tensor<2xf16>
|
||||
// CHECK: %[[YIELDED_MEMREF:.*]] = tensor_to_memref %[[TENSOR_VAL]] : memref<2xf16>
|
||||
// CHECK: shape.assuming_yield %[[YIELDED_MEMREF]] : memref<2xf16>
|
||||
// CHECK: }
|
||||
// CHECK: %[[TENSOR:.*]] = tensor_load %[[MEMREF:.*]] : memref<2xf16>
|
||||
// CHECK: "test.sink"(%[[TENSOR]]) : (tensor<2xf16>) -> ()
|
||||
// CHECK: return
|
||||
// CHECK: }
|
||||
func @shape_assuming() {
|
||||
%0 = shape.const_witness true
|
||||
// CHECK: shape.assuming %{{.*}} -> (memref<2xf16>) {
|
||||
%1 = shape.assuming %0 -> (tensor<2xf16>) {
|
||||
%2 = "test.source"() : () -> (tensor<2xf16>)
|
||||
shape.assuming_yield %2 : tensor<2xf16>
|
||||
|
|
Loading…
Reference in New Issue