[mlir][shape] Split out structural type conversions for shape dialect.

A "structural" type conversion is one where the underlying ops are
completely agnostic to the actual types involved and simply need to update
their types. An example of this is shape.assuming -- the shape.assuming op
and the corresponding shape.assuming_yield op need to update their types
accordingly to the TypeConverter, but otherwise don't care what type
conversions are happening.
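For illustration (not part of this patch): the structural patterns consult
whatever TypeConverter the caller provides. A minimal sketch of a
tensor-to-memref converter that could drive them; the
`buildTensorToMemrefConverter` helper is hypothetical:

```cpp
// Hypothetical helper, for illustration only; any TypeConverter works, since
// the structural patterns only ask it what each type should become.
#include "mlir/IR/StandardTypes.h"
#include "mlir/Transforms/DialectConversion.h"

static void buildTensorToMemrefConverter(mlir::TypeConverter &converter) {
  // Fallback: leave unrelated types unchanged. Conversions are tried in
  // reverse registration order, so register the fallback first.
  converter.addConversion([](mlir::Type type) { return type; });
  // Map ranked tensors to memrefs of the same shape and element type,
  // e.g. tensor<2xf16> -> memref<2xf16>.
  converter.addConversion([](mlir::RankedTensorType type) -> mlir::Type {
    return mlir::MemRefType::get(type.getShape(), type.getElementType());
  });
}
```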

Also, the previous conversion code would not correctly materialize
conversions for the shape.assuming_yield op. This should have caused a
verification failure, but shape.assuming's verifier wasn't calling
RegionBranchOpInterface::verifyTypes (which can't be called automatically
as part of the trait verification and must be invoked manually). This
patch also adds that verification.
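For illustration (not part of the diff below): in C++ terms, the TableGen
`verifier` entry added below amounts to roughly the following; the standalone
`verify` function shown here is a sketch, not the generated code:

```cpp
// Sketch only: what the TableGen `verifier` field added below boils down to.
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"

using namespace mlir;

// Checks that types agree along control-flow edges, e.g. that the operands of
// shape.assuming_yield match the result types of the enclosing shape.assuming.
static LogicalResult verify(shape::AssumingOp op) {
  return RegionBranchOpInterface::verifyTypes(op);
}
```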

Differential Revision: https://reviews.llvm.org/D89833
Sean Silva 2020-10-19 15:59:03 -07:00
parent f0292ede9b
commit 57b338c08a
6 changed files with 108 additions and 66 deletions


@@ -635,6 +635,7 @@ def Shape_AssumingOp : Shape_Op<"assuming",
   let printer = [{ return ::print(p, *this); }];
   let parser = [{ return ::parse$cppClass(parser, result); }];
+  let verifier = [{ return RegionBranchOpInterface::verifyTypes(*this); }];
 
   let extraClassDeclaration = [{
     // Inline the region into the region containing the AssumingOp and delete


@@ -17,7 +17,8 @@
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
-class BufferizeTypeConverter;
+class ConversionTarget;
+class TypeConverter;
 } // namespace mlir
 
 namespace mlir {
@@ -40,9 +41,21 @@ void populateRemoveShapeConstraintsPatterns(OwningRewritePatternList &patterns,
                                             MLIRContext *ctx);
 std::unique_ptr<FunctionPass> createRemoveShapeConstraintsPass();
 
-void populateShapeTypeConversionPatterns(MLIRContext *ctx,
-                                         BufferizeTypeConverter &converter,
-                                         OwningRewritePatternList &patterns);
+/// Populates patterns for shape dialect structural type conversions and sets up
+/// the provided ConversionTarget with the appropriate legality configuration
+/// for the ops to get converted properly.
+///
+/// A "structural" type conversion is one where the underlying ops are
+/// completely agnostic to the actual types involved and simply need to update
+/// their types consistently. An example of this is shape.assuming -- the
+/// shape.assuming op and the corresponding shape.assuming_yield op need to have
+/// consistent types, but the exact types don't matter. So all that we need to
+/// do for a structural type conversion is to update both of their types
+/// consistently to the new types prescribed by the TypeConverter.
+void populateShapeStructuralTypeConversionsAndLegality(
+    MLIRContext *context, TypeConverter &typeConverter,
+    OwningRewritePatternList &patterns, ConversionTarget &target);
 
 // Bufferizes shape dialect ops.
 //
 // Note that most shape dialect ops must be converted to std before


@@ -8,82 +8,30 @@
 #include "mlir/Transforms/Bufferize.h"
 #include "PassDetail.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/Shape/Transforms/Passes.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/StandardTypes.h"
 #include "mlir/Pass/Pass.h"
 
 using namespace mlir;
 using namespace mlir::shape;
 
 namespace {
-// Propagate tensor to memref conversions through shape.assuming ops.
-class TypeConversionAssumingOpConverter
-    : public BufferizeOpConversionPattern<shape::AssumingOp> {
-public:
-  using BufferizeOpConversionPattern<
-      shape::AssumingOp>::BufferizeOpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(shape::AssumingOp assumingOp, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
-    SmallVector<Type, 2> newResultTypes;
-    newResultTypes.reserve(assumingOp.getNumResults());
-    for (auto result : assumingOp.getResults()) {
-      auto originalType = result.getType();
-      Type convertedType = converter.convertType(originalType);
-      newResultTypes.push_back(convertedType);
-    }
-
-    auto newAssumingOp = rewriter.create<shape::AssumingOp>(
-        assumingOp.getLoc(), newResultTypes, assumingOp.witness());
-
-    rewriter.replaceOp(assumingOp, newAssumingOp.getResults());
-    rewriter.inlineRegionBefore(assumingOp.doRegion(), newAssumingOp.doRegion(),
-                                newAssumingOp.doRegion().end());
-
-    return success();
-  }
-};
-
 struct ShapeBufferizePass : public ShapeBufferizeBase<ShapeBufferizePass> {
   void runOnFunction() override {
     MLIRContext &ctx = getContext();
 
     OwningRewritePatternList patterns;
-    BufferizeTypeConverter converter;
-    populateShapeTypeConversionPatterns(&ctx, converter, patterns);
+    BufferizeTypeConverter typeConverter;
+    ConversionTarget target(getContext());
 
-    ConversionTarget target(getContext());
-    auto isMemRefType = [](Type type) { return type.isa<BaseMemRefType>(); };
-    target.addDynamicallyLegalOp<AssumingOp>([&](shape::AssumingOp op) {
-      return std::all_of(op.result_type_begin(), op.result_type_end(),
-                         isMemRefType);
-    });
+    populateBufferizeMaterializationLegality(target);
+    populateShapeStructuralTypeConversionsAndLegality(&ctx, typeConverter,
+                                                      patterns, target);
 
-    if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
+    if (failed(applyPartialConversion(getFunction(), target, patterns)))
       signalPassFailure();
   }
 };
 } // namespace
 
-/// Populates `patterns` with the conversion patterns of tensor->memref.
-//
-// TODO: Change this to work generally with any type conversions.
-void mlir::populateShapeTypeConversionPatterns(
-    MLIRContext *context, BufferizeTypeConverter &converter,
-    OwningRewritePatternList &patterns) {
-  patterns.insert<TypeConversionAssumingOpConverter>(context, converter);
-}
-
 //===----------------------------------------------------------------------===//
 // ShapeBufferizePass construction
 //===----------------------------------------------------------------------===//
 
 std::unique_ptr<FunctionPass> mlir::createShapeBufferizePass() {
   return std::make_unique<ShapeBufferizePass>();
 }


@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRShapeOpsTransforms
   Bufferize.cpp
   RemoveShapeConstraints.cpp
   ShapeToShapeLowering.cpp
+  StructuralTypeConversions.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ShapeOps/Transforms


@@ -0,0 +1,71 @@
+//===- StructuralTypeConversions.cpp - Shape structural type conversions --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace mlir::shape;
+
+namespace {
+class ConvertAssumingOpTypes : public OpConversionPattern<AssumingOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AssumingOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    SmallVector<Type, 2> newResultTypes;
+    newResultTypes.reserve(op.getNumResults());
+    for (auto result : op.getResults()) {
+      auto originalType = result.getType();
+      Type convertedType = getTypeConverter()->convertType(originalType);
+      newResultTypes.push_back(convertedType);
+    }
+
+    auto newAssumingOp =
+        rewriter.create<AssumingOp>(op.getLoc(), newResultTypes, op.witness());
+
+    rewriter.replaceOp(op, newAssumingOp.getResults());
+    rewriter.inlineRegionBefore(op.doRegion(), newAssumingOp.doRegion(),
+                                newAssumingOp.doRegion().end());
+
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class ConvertAssumingYieldOpTypes
+    : public OpConversionPattern<AssumingYieldOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AssumingYieldOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    rewriter.replaceOpWithNewOp<AssumingYieldOp>(op, operands);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateShapeStructuralTypeConversionsAndLegality(
+    MLIRContext *context, TypeConverter &typeConverter,
+    OwningRewritePatternList &patterns, ConversionTarget &target) {
+  patterns.insert<ConvertAssumingOpTypes, ConvertAssumingYieldOpTypes>(
+      typeConverter, context);
+  target.addDynamicallyLegalOp<AssumingOp>([&](AssumingOp op) {
+    return typeConverter.isLegal(op.getResultTypes());
+  });
+  target.addDynamicallyLegalOp<AssumingYieldOp>([&](AssumingYieldOp op) {
+    return typeConverter.isLegal(op.getOperandTypes());
+  });
+}
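For downstream users, a hypothetical usage sketch of this entry point; the
`convertStructuralTypes` helper is invented and simply mirrors how
ShapeBufferizePass above wires things up:

```cpp
// Hypothetical wrapper showing how a conversion pass consumes the new entry
// point; not part of this patch.
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

static LogicalResult convertStructuralTypes(Operation *op,
                                            TypeConverter &typeConverter) {
  MLIRContext *ctx = op->getContext();
  OwningRewritePatternList patterns;
  ConversionTarget target(*ctx);

  // One call registers both the conversion patterns and the dynamic legality
  // rules for shape.assuming / shape.assuming_yield.
  populateShapeStructuralTypeConversionsAndLegality(ctx, typeConverter,
                                                    patterns, target);

  return applyPartialConversion(op, target, patterns);
}
```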


@@ -1,12 +1,20 @@
 // RUN: mlir-opt -split-input-file -shape-bufferize <%s | FileCheck %s
 
 // -----
-// Check that shape.assuming returns a memref.
-//
-// CHECK-LABEL: @shape_assuming_returns_memref
-func @shape_assuming_returns_memref() {
+
+// CHECK-LABEL: func @shape_assuming() {
+// CHECK:         %[[WTRUE:.*]] = shape.const_witness true
+// CHECK:         %[[MEMREF:.*]] = shape.assuming %[[WTRUE]] -> (memref<2xf16>) {
+// CHECK:           %[[TENSOR_VAL:.*]] = "test.source"() : () -> tensor<2xf16>
+// CHECK:           %[[YIELDED_MEMREF:.*]] = tensor_to_memref %[[TENSOR_VAL]] : memref<2xf16>
+// CHECK:           shape.assuming_yield %[[YIELDED_MEMREF]] : memref<2xf16>
+// CHECK:         }
+// CHECK:         %[[TENSOR:.*]] = tensor_load %[[MEMREF:.*]] : memref<2xf16>
+// CHECK:         "test.sink"(%[[TENSOR]]) : (tensor<2xf16>) -> ()
+// CHECK:         return
+// CHECK:       }
+func @shape_assuming() {
   %0 = shape.const_witness true
-  // CHECK: shape.assuming %{{.*}} -> (memref<2xf16>) {
   %1 = shape.assuming %0 -> (tensor<2xf16>) {
     %2 = "test.source"() : () -> (tensor<2xf16>)
     shape.assuming_yield %2 : tensor<2xf16>