[mlir][linalg] Remove RangeOp and RangeType.

Remove the RangeOp and the RangeType that are not actively used anymore. After removing RangeType, the LinalgTypes header only includes the generated dialect header.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D115727
This commit is contained in:
gysit 2021-12-15 07:10:32 +00:00
parent 7029a688c9
commit 9912bed730
10 changed files with 3 additions and 165 deletions

View File

@ -520,7 +520,6 @@ generally alias the operand `view`. At the moment the existing ops are:
* `memref.view`, * `memref.view`,
* `memref.subview`, * `memref.subview`,
* `memref.transpose`. * `memref.transpose`.
* `linalg.range`,
* `linalg.slice`, * `linalg.slice`,
* `linalg.reshape`, * `linalg.reshape`,
``` ```

View File

@ -58,8 +58,4 @@ def Linalg_Dialect : Dialect {
}]; }];
} }
// Whether a type is a RangeType.
def LinalgIsRangeTypePred : CPred<"$_self.isa<RangeType>()">;
def Range : DialectType<Linalg_Dialect, LinalgIsRangeTypePred, "range">;
#endif // LINALG_BASE #endif // LINALG_BASE

View File

@ -330,34 +330,6 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
let hasFolder = 1; let hasFolder = 1;
} }
def Linalg_RangeOp :
Linalg_Op<"range", [NoSideEffect]>,
Arguments<(ins Index:$min, Index:$max, Index:$step)>,
Results<(outs Range)> {
let summary = "Create a `range` type value, used to create `view`s";
let description = [{
The `linalg.range` op creates a `!linalg.range` from 3 values of type
`index` that represent the min, max and step values of the `range`. This
type does not pass function boundaries at the moment.
Example:
```mlir
%3 = linalg.range %0:%1:%2 : !linalg.range
````
}];
let builders = [
OpBuilder<(ins "Value":$min, "Value":$max, "Value":$step),
[{
auto rangeType = RangeType::get($_builder.getContext());
build($_builder, $_state, rangeType, min, max, step);
}]>];
// Fully specified by traits.
let verifier = ?;
let assemblyFormat = "$min `:` $max `:` $step attr-dict `:` type(results)";
}
def Linalg_YieldOp : Linalg_Op<"yield", [NoSideEffect, ReturnLike, Terminator]>, def Linalg_YieldOp : Linalg_Op<"yield", [NoSideEffect, ReturnLike, Terminator]>,
Arguments<(ins Variadic<AnyType>:$values)> { Arguments<(ins Variadic<AnyType>:$values)> {
let summary = "Linalg yield operation"; let summary = "Linalg yield operation";

View File

@ -22,27 +22,4 @@
#include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.h.inc" #include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.h.inc"
namespace mlir {
class MLIRContext;
namespace linalg {
/// A RangeType represents a minimal range abstraction (min, max, step).
/// It is constructed by calling the linalg.range op with three values index of
/// index type:
///
/// ```mlir
/// func @foo(%arg0 : index, %arg1 : index, %arg2 : index) {
/// %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
/// }
/// ```
class RangeType : public Type::TypeBase<RangeType, Type, TypeStorage> {
public:
// Used for generic hooks in TypeBase.
using Base::Base;
};
} // namespace linalg
} // namespace mlir
#endif // MLIR_DIALECT_LINALG_LINALGTYPES_H_ #endif // MLIR_DIALECT_LINALG_LINALGTYPES_H_

View File

@ -52,48 +52,7 @@ static Type getPtrToElementType(T containerType, LLVMTypeConverter &lowering) {
lowering.convertType(containerType.getElementType())); lowering.convertType(containerType.getElementType()));
} }
/// Convert the given range descriptor type to the LLVMIR dialect.
/// Range descriptor contains the range bounds and the step as 64-bit integers.
///
/// struct {
/// int64_t min;
/// int64_t max;
/// int64_t step;
/// };
static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) {
auto *context = t.getContext();
auto int64Ty = converter.convertType(IntegerType::get(context, 64));
return LLVMStructType::getLiteral(context, {int64Ty, int64Ty, int64Ty});
}
namespace { namespace {
// RangeOp creates a new range descriptor.
class RangeOpConversion : public ConvertOpToLLVMPattern<RangeOp> {
public:
using ConvertOpToLLVMPattern<RangeOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(RangeOp rangeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto rangeDescriptorTy = convertRangeType(
rangeOp.getType().cast<RangeType>(), *getTypeConverter());
ImplicitLocOpBuilder b(rangeOp->getLoc(), rewriter);
// Fill in an aggregate value of the descriptor.
Value desc = b.create<LLVM::UndefOp>(rangeDescriptorTy);
desc = b.create<LLVM::InsertValueOp>(desc, adaptor.min(),
rewriter.getI64ArrayAttr(0));
desc = b.create<LLVM::InsertValueOp>(desc, adaptor.max(),
rewriter.getI64ArrayAttr(1));
desc = b.create<LLVM::InsertValueOp>(desc, adaptor.step(),
rewriter.getI64ArrayAttr(2));
rewriter.replaceOp(rangeOp, desc);
return success();
}
};
// YieldOp produces and LLVM::ReturnOp. // YieldOp produces and LLVM::ReturnOp.
class YieldOpConversion : public ConvertOpToLLVMPattern<linalg::YieldOp> { class YieldOpConversion : public ConvertOpToLLVMPattern<linalg::YieldOp> {
public: public:
@ -111,11 +70,7 @@ public:
/// Populate the given list with patterns that convert from Linalg to LLVM. /// Populate the given list with patterns that convert from Linalg to LLVM.
void mlir::populateLinalgToLLVMConversionPatterns(LLVMTypeConverter &converter, void mlir::populateLinalgToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) { RewritePatternSet &patterns) {
patterns.add<RangeOpConversion, YieldOpConversion>(converter); patterns.add<YieldOpConversion>(converter);
// Populate the type conversions for the linalg types.
converter.addConversion(
[&](RangeType type) { return convertRangeType(type, converter); });
} }
namespace { namespace {
@ -135,7 +90,6 @@ void ConvertLinalgToLLVMPass::runOnOperation() {
populateMemRefToLLVMConversionPatterns(converter, patterns); populateMemRefToLLVMConversionPatterns(converter, patterns);
LLVMConversionTarget target(getContext()); LLVMConversionTarget target(getContext());
target.addIllegalOp<RangeOp>();
target.addLegalOp<ModuleOp>(); target.addLegalOp<ModuleOp>();
if (failed(applyPartialConversion(module, target, std::move(patterns)))) if (failed(applyPartialConversion(module, target, std::move(patterns))))
signalPassFailure(); signalPassFailure();

View File

@ -187,7 +187,7 @@ void ConvertLinalgToStandardPass::runOnOperation() {
target.addLegalDialect<AffineDialect, arith::ArithmeticDialect, target.addLegalDialect<AffineDialect, arith::ArithmeticDialect,
memref::MemRefDialect, scf::SCFDialect, memref::MemRefDialect, scf::SCFDialect,
StandardOpsDialect>(); StandardOpsDialect>();
target.addLegalOp<ModuleOp, FuncOp, ReturnOp, linalg::RangeOp>(); target.addLegalOp<ModuleOp, FuncOp, ReturnOp>();
RewritePatternSet patterns(&getContext()); RewritePatternSet patterns(&getContext());
populateLinalgToStandardConversionPatterns(patterns); populateLinalgToStandardConversionPatterns(patterns);
if (failed(applyFullConversion(module, target, std::move(patterns)))) if (failed(applyFullConversion(module, target, std::move(patterns))))

View File

@ -106,7 +106,6 @@ void addNamedOpBuilders(
} }
void mlir::linalg::LinalgDialect::initialize() { void mlir::linalg::LinalgDialect::initialize() {
addTypes<RangeType>();
addOperations< addOperations<
#define GET_OP_LIST #define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc" #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
@ -125,29 +124,6 @@ void mlir::linalg::LinalgDialect::initialize() {
addInterfaces<LinalgInlinerInterface>(); addInterfaces<LinalgInlinerInterface>();
} }
Type mlir::linalg::LinalgDialect::parseType(DialectAsmParser &parser) const {
// Parse the main keyword for the type.
StringRef keyword;
if (parser.parseKeyword(&keyword))
return Type();
MLIRContext *context = getContext();
// Handle 'range' types.
if (keyword == "range")
return RangeType::get(context);
parser.emitError(parser.getNameLoc(), "unknown Linalg type: " + keyword);
return Type();
}
/// RangeType prints as just "range".
static void print(RangeType rt, DialectAsmPrinter &os) { os << "range"; }
void mlir::linalg::LinalgDialect::printType(Type type,
DialectAsmPrinter &os) const {
print(type.cast<RangeType>(), os);
}
LogicalResult LinalgDialect::verifyOperationAttribute(Operation *op, LogicalResult LinalgDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attr) { NamedAttribute attr) {
using comprehensive_bufferize::BufferizableOpInterface; using comprehensive_bufferize::BufferizableOpInterface;

View File

@ -298,16 +298,6 @@ func @generic(%arg0: memref<?x?xi4>) {
// //
// // ----- // // -----
// expected-error @+1 {{unknown Linalg type}}
!invalid_type = type !linalg.unknown
// -----
// expected-error @+1 {{expected valid keyword}}
!invalid_type = type !linalg<"?">
// -----
func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?xf32>, %c3: memref<?x?x?xf32>) { func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?xf32>, %c3: memref<?x?x?xf32>) {
// expected-error @+1 {{expected operand rank (2) to match the result rank of indexing_map #1 (3)}} // expected-error @+1 {{expected operand rank (2) to match the result rank of indexing_map #1 (3)}}
linalg.batch_matmul ins(%a3, %b3: memref<?x?x?xf32>, memref<?x?xf32>) linalg.batch_matmul ins(%a3, %b3: memref<?x?x?xf32>, memref<?x?xf32>)

View File

@ -1,15 +0,0 @@
// RUN: mlir-opt %s -convert-linalg-to-llvm | FileCheck %s
func @range(%arg0: index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%R = linalg.range %c0:%arg0:%c1 : !linalg.range
return
}
// CHECK-LABEL: func @range
// CHECK: arith.constant 0 : index
// CHECK: arith.constant 1 : index
// CHECK: llvm.mlir.undef : !llvm.struct<(i64, i64, i64)>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(i64, i64, i64)>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(i64, i64, i64)>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(i64, i64, i64)>

View File

@ -86,20 +86,10 @@ func @pad_to_static_size(%arg0: tensor<?x?xf32>, %ub0: index, %ub1: index,
// ----- // -----
func @range(%arg0: index, %arg1: index, %arg2: index) { func @views(%arg0: index) {
%0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
return
}
// CHECK-LABEL: func @range(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
// CHECK-NEXT: linalg.range %{{.*}} : %{{.*}} : %{{.*}} : !linalg.range
// -----
func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
%c0 = arith.constant 0 : index %c0 = arith.constant 0 : index
%0 = arith.muli %arg0, %arg0 : index %0 = arith.muli %arg0, %arg0 : index
%1 = memref.alloc (%0) : memref<?xi8> %1 = memref.alloc (%0) : memref<?xi8>
%2 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
%3 = memref.view %1[%c0][%arg0, %arg0] : memref<?xi8> to memref<?x?xf32> %3 = memref.view %1[%c0][%arg0, %arg0] : memref<?xi8> to memref<?x?xf32>
%4 = memref.view %1[%c0][%arg0, %arg0] : memref<?xi8> to memref<?x?xvector<4x4xf32>> %4 = memref.view %1[%c0][%arg0, %arg0] : memref<?xi8> to memref<?x?xvector<4x4xf32>>
memref.dealloc %1 : memref<?xi8> memref.dealloc %1 : memref<?xi8>
@ -108,7 +98,6 @@ func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index
// CHECK-LABEL: func @views // CHECK-LABEL: func @views
// CHECK: arith.muli %{{.*}}, %{{.*}} : index // CHECK: arith.muli %{{.*}}, %{{.*}} : index
// CHECK-NEXT: memref.alloc(%{{.*}}) : memref<?xi8> // CHECK-NEXT: memref.alloc(%{{.*}}) : memref<?xi8>
// CHECK-NEXT: range
// CHECK-NEXT: memref.view %{{.*}}[%{{.*}}][%{{.*}}] : // CHECK-NEXT: memref.view %{{.*}}[%{{.*}}][%{{.*}}] :
// CHECK-SAME: memref<?xi8> to memref<?x?xf32> // CHECK-SAME: memref<?xi8> to memref<?x?xf32>
// CHECK-NEXT: memref.view %{{.*}}[%{{.*}}][%{{.*}}] : // CHECK-NEXT: memref.view %{{.*}}[%{{.*}}][%{{.*}}] :