forked from OSchip/llvm-project
[mlir][linalg] Remove RangeOp and RangeType.
Remove the RangeOp and the RangeType that are not actively used anymore. After removing RangeType, the LinalgTypes header only includes the generated dialect header. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D115727
This commit is contained in:
parent
7029a688c9
commit
9912bed730
|
@ -520,7 +520,6 @@ generally alias the operand `view`. At the moment the existing ops are:
|
||||||
* `memref.view`,
|
* `memref.view`,
|
||||||
* `memref.subview`,
|
* `memref.subview`,
|
||||||
* `memref.transpose`.
|
* `memref.transpose`.
|
||||||
* `linalg.range`,
|
|
||||||
* `linalg.slice`,
|
* `linalg.slice`,
|
||||||
* `linalg.reshape`,
|
* `linalg.reshape`,
|
||||||
```
|
```
|
||||||
|
|
|
@ -58,8 +58,4 @@ def Linalg_Dialect : Dialect {
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
// Whether a type is a RangeType.
|
|
||||||
def LinalgIsRangeTypePred : CPred<"$_self.isa<RangeType>()">;
|
|
||||||
def Range : DialectType<Linalg_Dialect, LinalgIsRangeTypePred, "range">;
|
|
||||||
|
|
||||||
#endif // LINALG_BASE
|
#endif // LINALG_BASE
|
||||||
|
|
|
@ -330,34 +330,6 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def Linalg_RangeOp :
|
|
||||||
Linalg_Op<"range", [NoSideEffect]>,
|
|
||||||
Arguments<(ins Index:$min, Index:$max, Index:$step)>,
|
|
||||||
Results<(outs Range)> {
|
|
||||||
let summary = "Create a `range` type value, used to create `view`s";
|
|
||||||
let description = [{
|
|
||||||
The `linalg.range` op creates a `!linalg.range` from 3 values of type
|
|
||||||
`index` that represent the min, max and step values of the `range`. This
|
|
||||||
type does not pass function boundaries at the moment.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```mlir
|
|
||||||
%3 = linalg.range %0:%1:%2 : !linalg.range
|
|
||||||
````
|
|
||||||
}];
|
|
||||||
let builders = [
|
|
||||||
OpBuilder<(ins "Value":$min, "Value":$max, "Value":$step),
|
|
||||||
[{
|
|
||||||
auto rangeType = RangeType::get($_builder.getContext());
|
|
||||||
build($_builder, $_state, rangeType, min, max, step);
|
|
||||||
}]>];
|
|
||||||
|
|
||||||
// Fully specified by traits.
|
|
||||||
let verifier = ?;
|
|
||||||
let assemblyFormat = "$min `:` $max `:` $step attr-dict `:` type(results)";
|
|
||||||
}
|
|
||||||
|
|
||||||
def Linalg_YieldOp : Linalg_Op<"yield", [NoSideEffect, ReturnLike, Terminator]>,
|
def Linalg_YieldOp : Linalg_Op<"yield", [NoSideEffect, ReturnLike, Terminator]>,
|
||||||
Arguments<(ins Variadic<AnyType>:$values)> {
|
Arguments<(ins Variadic<AnyType>:$values)> {
|
||||||
let summary = "Linalg yield operation";
|
let summary = "Linalg yield operation";
|
||||||
|
|
|
@ -22,27 +22,4 @@
|
||||||
|
|
||||||
#include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.h.inc"
|
#include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.h.inc"
|
||||||
|
|
||||||
namespace mlir {
|
|
||||||
class MLIRContext;
|
|
||||||
|
|
||||||
namespace linalg {
|
|
||||||
|
|
||||||
/// A RangeType represents a minimal range abstraction (min, max, step).
|
|
||||||
/// It is constructed by calling the linalg.range op with three values index of
|
|
||||||
/// index type:
|
|
||||||
///
|
|
||||||
/// ```mlir
|
|
||||||
/// func @foo(%arg0 : index, %arg1 : index, %arg2 : index) {
|
|
||||||
/// %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
|
|
||||||
/// }
|
|
||||||
/// ```
|
|
||||||
class RangeType : public Type::TypeBase<RangeType, Type, TypeStorage> {
|
|
||||||
public:
|
|
||||||
// Used for generic hooks in TypeBase.
|
|
||||||
using Base::Base;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace linalg
|
|
||||||
} // namespace mlir
|
|
||||||
|
|
||||||
#endif // MLIR_DIALECT_LINALG_LINALGTYPES_H_
|
#endif // MLIR_DIALECT_LINALG_LINALGTYPES_H_
|
||||||
|
|
|
@ -52,48 +52,7 @@ static Type getPtrToElementType(T containerType, LLVMTypeConverter &lowering) {
|
||||||
lowering.convertType(containerType.getElementType()));
|
lowering.convertType(containerType.getElementType()));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convert the given range descriptor type to the LLVMIR dialect.
|
|
||||||
/// Range descriptor contains the range bounds and the step as 64-bit integers.
|
|
||||||
///
|
|
||||||
/// struct {
|
|
||||||
/// int64_t min;
|
|
||||||
/// int64_t max;
|
|
||||||
/// int64_t step;
|
|
||||||
/// };
|
|
||||||
static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) {
|
|
||||||
auto *context = t.getContext();
|
|
||||||
auto int64Ty = converter.convertType(IntegerType::get(context, 64));
|
|
||||||
return LLVMStructType::getLiteral(context, {int64Ty, int64Ty, int64Ty});
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// RangeOp creates a new range descriptor.
|
|
||||||
class RangeOpConversion : public ConvertOpToLLVMPattern<RangeOp> {
|
|
||||||
public:
|
|
||||||
using ConvertOpToLLVMPattern<RangeOp>::ConvertOpToLLVMPattern;
|
|
||||||
|
|
||||||
LogicalResult
|
|
||||||
matchAndRewrite(RangeOp rangeOp, OpAdaptor adaptor,
|
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
|
||||||
auto rangeDescriptorTy = convertRangeType(
|
|
||||||
rangeOp.getType().cast<RangeType>(), *getTypeConverter());
|
|
||||||
|
|
||||||
ImplicitLocOpBuilder b(rangeOp->getLoc(), rewriter);
|
|
||||||
|
|
||||||
// Fill in an aggregate value of the descriptor.
|
|
||||||
Value desc = b.create<LLVM::UndefOp>(rangeDescriptorTy);
|
|
||||||
desc = b.create<LLVM::InsertValueOp>(desc, adaptor.min(),
|
|
||||||
rewriter.getI64ArrayAttr(0));
|
|
||||||
desc = b.create<LLVM::InsertValueOp>(desc, adaptor.max(),
|
|
||||||
rewriter.getI64ArrayAttr(1));
|
|
||||||
desc = b.create<LLVM::InsertValueOp>(desc, adaptor.step(),
|
|
||||||
rewriter.getI64ArrayAttr(2));
|
|
||||||
rewriter.replaceOp(rangeOp, desc);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
// YieldOp produces and LLVM::ReturnOp.
|
// YieldOp produces and LLVM::ReturnOp.
|
||||||
class YieldOpConversion : public ConvertOpToLLVMPattern<linalg::YieldOp> {
|
class YieldOpConversion : public ConvertOpToLLVMPattern<linalg::YieldOp> {
|
||||||
public:
|
public:
|
||||||
|
@ -111,11 +70,7 @@ public:
|
||||||
/// Populate the given list with patterns that convert from Linalg to LLVM.
|
/// Populate the given list with patterns that convert from Linalg to LLVM.
|
||||||
void mlir::populateLinalgToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
void mlir::populateLinalgToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
||||||
RewritePatternSet &patterns) {
|
RewritePatternSet &patterns) {
|
||||||
patterns.add<RangeOpConversion, YieldOpConversion>(converter);
|
patterns.add<YieldOpConversion>(converter);
|
||||||
|
|
||||||
// Populate the type conversions for the linalg types.
|
|
||||||
converter.addConversion(
|
|
||||||
[&](RangeType type) { return convertRangeType(type, converter); });
|
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -135,7 +90,6 @@ void ConvertLinalgToLLVMPass::runOnOperation() {
|
||||||
populateMemRefToLLVMConversionPatterns(converter, patterns);
|
populateMemRefToLLVMConversionPatterns(converter, patterns);
|
||||||
|
|
||||||
LLVMConversionTarget target(getContext());
|
LLVMConversionTarget target(getContext());
|
||||||
target.addIllegalOp<RangeOp>();
|
|
||||||
target.addLegalOp<ModuleOp>();
|
target.addLegalOp<ModuleOp>();
|
||||||
if (failed(applyPartialConversion(module, target, std::move(patterns))))
|
if (failed(applyPartialConversion(module, target, std::move(patterns))))
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
|
|
|
@ -187,7 +187,7 @@ void ConvertLinalgToStandardPass::runOnOperation() {
|
||||||
target.addLegalDialect<AffineDialect, arith::ArithmeticDialect,
|
target.addLegalDialect<AffineDialect, arith::ArithmeticDialect,
|
||||||
memref::MemRefDialect, scf::SCFDialect,
|
memref::MemRefDialect, scf::SCFDialect,
|
||||||
StandardOpsDialect>();
|
StandardOpsDialect>();
|
||||||
target.addLegalOp<ModuleOp, FuncOp, ReturnOp, linalg::RangeOp>();
|
target.addLegalOp<ModuleOp, FuncOp, ReturnOp>();
|
||||||
RewritePatternSet patterns(&getContext());
|
RewritePatternSet patterns(&getContext());
|
||||||
populateLinalgToStandardConversionPatterns(patterns);
|
populateLinalgToStandardConversionPatterns(patterns);
|
||||||
if (failed(applyFullConversion(module, target, std::move(patterns))))
|
if (failed(applyFullConversion(module, target, std::move(patterns))))
|
||||||
|
|
|
@ -106,7 +106,6 @@ void addNamedOpBuilders(
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlir::linalg::LinalgDialect::initialize() {
|
void mlir::linalg::LinalgDialect::initialize() {
|
||||||
addTypes<RangeType>();
|
|
||||||
addOperations<
|
addOperations<
|
||||||
#define GET_OP_LIST
|
#define GET_OP_LIST
|
||||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
|
#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
|
||||||
|
@ -125,29 +124,6 @@ void mlir::linalg::LinalgDialect::initialize() {
|
||||||
addInterfaces<LinalgInlinerInterface>();
|
addInterfaces<LinalgInlinerInterface>();
|
||||||
}
|
}
|
||||||
|
|
||||||
Type mlir::linalg::LinalgDialect::parseType(DialectAsmParser &parser) const {
|
|
||||||
// Parse the main keyword for the type.
|
|
||||||
StringRef keyword;
|
|
||||||
if (parser.parseKeyword(&keyword))
|
|
||||||
return Type();
|
|
||||||
MLIRContext *context = getContext();
|
|
||||||
|
|
||||||
// Handle 'range' types.
|
|
||||||
if (keyword == "range")
|
|
||||||
return RangeType::get(context);
|
|
||||||
|
|
||||||
parser.emitError(parser.getNameLoc(), "unknown Linalg type: " + keyword);
|
|
||||||
return Type();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// RangeType prints as just "range".
|
|
||||||
static void print(RangeType rt, DialectAsmPrinter &os) { os << "range"; }
|
|
||||||
|
|
||||||
void mlir::linalg::LinalgDialect::printType(Type type,
|
|
||||||
DialectAsmPrinter &os) const {
|
|
||||||
print(type.cast<RangeType>(), os);
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult LinalgDialect::verifyOperationAttribute(Operation *op,
|
LogicalResult LinalgDialect::verifyOperationAttribute(Operation *op,
|
||||||
NamedAttribute attr) {
|
NamedAttribute attr) {
|
||||||
using comprehensive_bufferize::BufferizableOpInterface;
|
using comprehensive_bufferize::BufferizableOpInterface;
|
||||||
|
|
|
@ -298,16 +298,6 @@ func @generic(%arg0: memref<?x?xi4>) {
|
||||||
//
|
//
|
||||||
// // -----
|
// // -----
|
||||||
|
|
||||||
// expected-error @+1 {{unknown Linalg type}}
|
|
||||||
!invalid_type = type !linalg.unknown
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// expected-error @+1 {{expected valid keyword}}
|
|
||||||
!invalid_type = type !linalg<"?">
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?xf32>, %c3: memref<?x?x?xf32>) {
|
func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?xf32>, %c3: memref<?x?x?xf32>) {
|
||||||
// expected-error @+1 {{expected operand rank (2) to match the result rank of indexing_map #1 (3)}}
|
// expected-error @+1 {{expected operand rank (2) to match the result rank of indexing_map #1 (3)}}
|
||||||
linalg.batch_matmul ins(%a3, %b3: memref<?x?x?xf32>, memref<?x?xf32>)
|
linalg.batch_matmul ins(%a3, %b3: memref<?x?x?xf32>, memref<?x?xf32>)
|
||||||
|
|
|
@ -1,15 +0,0 @@
|
||||||
// RUN: mlir-opt %s -convert-linalg-to-llvm | FileCheck %s
|
|
||||||
|
|
||||||
func @range(%arg0: index) {
|
|
||||||
%c0 = arith.constant 0 : index
|
|
||||||
%c1 = arith.constant 1 : index
|
|
||||||
%R = linalg.range %c0:%arg0:%c1 : !linalg.range
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// CHECK-LABEL: func @range
|
|
||||||
// CHECK: arith.constant 0 : index
|
|
||||||
// CHECK: arith.constant 1 : index
|
|
||||||
// CHECK: llvm.mlir.undef : !llvm.struct<(i64, i64, i64)>
|
|
||||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(i64, i64, i64)>
|
|
||||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(i64, i64, i64)>
|
|
||||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(i64, i64, i64)>
|
|
|
@ -86,20 +86,10 @@ func @pad_to_static_size(%arg0: tensor<?x?xf32>, %ub0: index, %ub1: index,
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @range(%arg0: index, %arg1: index, %arg2: index) {
|
func @views(%arg0: index) {
|
||||||
%0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// CHECK-LABEL: func @range(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
|
|
||||||
// CHECK-NEXT: linalg.range %{{.*}} : %{{.*}} : %{{.*}} : !linalg.range
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
|
|
||||||
%c0 = arith.constant 0 : index
|
%c0 = arith.constant 0 : index
|
||||||
%0 = arith.muli %arg0, %arg0 : index
|
%0 = arith.muli %arg0, %arg0 : index
|
||||||
%1 = memref.alloc (%0) : memref<?xi8>
|
%1 = memref.alloc (%0) : memref<?xi8>
|
||||||
%2 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
|
|
||||||
%3 = memref.view %1[%c0][%arg0, %arg0] : memref<?xi8> to memref<?x?xf32>
|
%3 = memref.view %1[%c0][%arg0, %arg0] : memref<?xi8> to memref<?x?xf32>
|
||||||
%4 = memref.view %1[%c0][%arg0, %arg0] : memref<?xi8> to memref<?x?xvector<4x4xf32>>
|
%4 = memref.view %1[%c0][%arg0, %arg0] : memref<?xi8> to memref<?x?xvector<4x4xf32>>
|
||||||
memref.dealloc %1 : memref<?xi8>
|
memref.dealloc %1 : memref<?xi8>
|
||||||
|
@ -108,7 +98,6 @@ func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index
|
||||||
// CHECK-LABEL: func @views
|
// CHECK-LABEL: func @views
|
||||||
// CHECK: arith.muli %{{.*}}, %{{.*}} : index
|
// CHECK: arith.muli %{{.*}}, %{{.*}} : index
|
||||||
// CHECK-NEXT: memref.alloc(%{{.*}}) : memref<?xi8>
|
// CHECK-NEXT: memref.alloc(%{{.*}}) : memref<?xi8>
|
||||||
// CHECK-NEXT: range
|
|
||||||
// CHECK-NEXT: memref.view %{{.*}}[%{{.*}}][%{{.*}}] :
|
// CHECK-NEXT: memref.view %{{.*}}[%{{.*}}][%{{.*}}] :
|
||||||
// CHECK-SAME: memref<?xi8> to memref<?x?xf32>
|
// CHECK-SAME: memref<?xi8> to memref<?x?xf32>
|
||||||
// CHECK-NEXT: memref.view %{{.*}}[%{{.*}}][%{{.*}}] :
|
// CHECK-NEXT: memref.view %{{.*}}[%{{.*}}][%{{.*}}] :
|
||||||
|
|
Loading…
Reference in New Issue