diff --git a/mlir/docs/Dialects/Linalg/_index.md b/mlir/docs/Dialects/Linalg/_index.md index 3c2742ac51f1..790f858dad26 100644 --- a/mlir/docs/Dialects/Linalg/_index.md +++ b/mlir/docs/Dialects/Linalg/_index.md @@ -520,7 +520,6 @@ generally alias the operand `view`. At the moment the existing ops are: * `memref.view`, * `memref.subview`, * `memref.transpose`. -* `linalg.range`, * `linalg.slice`, * `linalg.reshape`, ``` diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td index de3703b71acb..de5bc6d33e67 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td @@ -58,8 +58,4 @@ def Linalg_Dialect : Dialect { }]; } -// Whether a type is a RangeType. -def LinalgIsRangeTypePred : CPred<"$_self.isa()">; -def Range : DialectType; - #endif // LINALG_BASE diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td index df39b311f410..a5c756198192 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -330,34 +330,6 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor", let hasFolder = 1; } -def Linalg_RangeOp : - Linalg_Op<"range", [NoSideEffect]>, - Arguments<(ins Index:$min, Index:$max, Index:$step)>, - Results<(outs Range)> { - let summary = "Create a `range` type value, used to create `view`s"; - let description = [{ - The `linalg.range` op creates a `!linalg.range` from 3 values of type - `index` that represent the min, max and step values of the `range`. This - type does not pass function boundaries at the moment. - - Example: - - ```mlir - %3 = linalg.range %0:%1:%2 : !linalg.range - ```` - }]; - let builders = [ - OpBuilder<(ins "Value":$min, "Value":$max, "Value":$step), - [{ - auto rangeType = RangeType::get($_builder.getContext()); - build($_builder, $_state, rangeType, min, max, step); - }]>]; - - // Fully specified by traits. - let verifier = ?; - let assemblyFormat = "$min `:` $max `:` $step attr-dict `:` type(results)"; -} - def Linalg_YieldOp : Linalg_Op<"yield", [NoSideEffect, ReturnLike, Terminator]>, Arguments<(ins Variadic:$values)> { let summary = "Linalg yield operation"; diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h index 396cc3b59120..3c99ecb4dda1 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h @@ -22,27 +22,4 @@ #include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.h.inc" -namespace mlir { -class MLIRContext; - -namespace linalg { - -/// A RangeType represents a minimal range abstraction (min, max, step). -/// It is constructed by calling the linalg.range op with three values index of -/// index type: -/// -/// ```mlir -/// func @foo(%arg0 : index, %arg1 : index, %arg2 : index) { -/// %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range -/// } -/// ``` -class RangeType : public Type::TypeBase { -public: - // Used for generic hooks in TypeBase. - using Base::Base; -}; - -} // namespace linalg -} // namespace mlir - #endif // MLIR_DIALECT_LINALG_LINALGTYPES_H_ diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp index 713890425acc..478d73c07ac7 100644 --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -52,48 +52,7 @@ static Type getPtrToElementType(T containerType, LLVMTypeConverter &lowering) { lowering.convertType(containerType.getElementType())); } -/// Convert the given range descriptor type to the LLVMIR dialect. -/// Range descriptor contains the range bounds and the step as 64-bit integers. -/// -/// struct { -/// int64_t min; -/// int64_t max; -/// int64_t step; -/// }; -static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) { - auto *context = t.getContext(); - auto int64Ty = converter.convertType(IntegerType::get(context, 64)); - return LLVMStructType::getLiteral(context, {int64Ty, int64Ty, int64Ty}); -} - namespace { -// RangeOp creates a new range descriptor. -class RangeOpConversion : public ConvertOpToLLVMPattern { -public: - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(RangeOp rangeOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto rangeDescriptorTy = convertRangeType( - rangeOp.getType().cast(), *getTypeConverter()); - - ImplicitLocOpBuilder b(rangeOp->getLoc(), rewriter); - - // Fill in an aggregate value of the descriptor. - Value desc = b.create(rangeDescriptorTy); - desc = b.create(desc, adaptor.min(), - rewriter.getI64ArrayAttr(0)); - desc = b.create(desc, adaptor.max(), - rewriter.getI64ArrayAttr(1)); - desc = b.create(desc, adaptor.step(), - rewriter.getI64ArrayAttr(2)); - rewriter.replaceOp(rangeOp, desc); - return success(); - } -}; - - // YieldOp produces and LLVM::ReturnOp. class YieldOpConversion : public ConvertOpToLLVMPattern { public: @@ -111,11 +70,7 @@ public: /// Populate the given list with patterns that convert from Linalg to LLVM. void mlir::populateLinalgToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns) { - patterns.add(converter); - - // Populate the type conversions for the linalg types. - converter.addConversion( - [&](RangeType type) { return convertRangeType(type, converter); }); + patterns.add(converter); } namespace { @@ -135,7 +90,6 @@ void ConvertLinalgToLLVMPass::runOnOperation() { populateMemRefToLLVMConversionPatterns(converter, patterns); LLVMConversionTarget target(getContext()); - target.addIllegalOp(); target.addLegalOp(); if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp index 6227d21521e4..e0f909d48297 100644 --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp @@ -187,7 +187,7 @@ void ConvertLinalgToStandardPass::runOnOperation() { target.addLegalDialect(); - target.addLegalOp(); + target.addLegalOp(); RewritePatternSet patterns(&getContext()); populateLinalgToStandardConversionPatterns(patterns); if (failed(applyFullConversion(module, target, std::move(patterns)))) diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp index 4d004352c683..c06c0f4a76c2 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp @@ -106,7 +106,6 @@ void addNamedOpBuilders( } void mlir::linalg::LinalgDialect::initialize() { - addTypes(); addOperations< #define GET_OP_LIST #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc" @@ -125,29 +124,6 @@ void mlir::linalg::LinalgDialect::initialize() { addInterfaces(); } -Type mlir::linalg::LinalgDialect::parseType(DialectAsmParser &parser) const { - // Parse the main keyword for the type. - StringRef keyword; - if (parser.parseKeyword(&keyword)) - return Type(); - MLIRContext *context = getContext(); - - // Handle 'range' types. - if (keyword == "range") - return RangeType::get(context); - - parser.emitError(parser.getNameLoc(), "unknown Linalg type: " + keyword); - return Type(); -} - -/// RangeType prints as just "range". -static void print(RangeType rt, DialectAsmPrinter &os) { os << "range"; } - -void mlir::linalg::LinalgDialect::printType(Type type, - DialectAsmPrinter &os) const { - print(type.cast(), os); -} - LogicalResult LinalgDialect::verifyOperationAttribute(Operation *op, NamedAttribute attr) { using comprehensive_bufferize::BufferizableOpInterface; diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index 5c7549cd45ad..ab51fe5445b5 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -298,16 +298,6 @@ func @generic(%arg0: memref) { // // // ----- -// expected-error @+1 {{unknown Linalg type}} -!invalid_type = type !linalg.unknown - -// ----- - -// expected-error @+1 {{expected valid keyword}} -!invalid_type = type !linalg<"?"> - -// ----- - func @named_ops(%a3: memref, %b3: memref, %c3: memref) { // expected-error @+1 {{expected operand rank (2) to match the result rank of indexing_map #1 (3)}} linalg.batch_matmul ins(%a3, %b3: memref, memref) diff --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir deleted file mode 100644 index f6ab826ae151..000000000000 --- a/mlir/test/Dialect/Linalg/llvm.mlir +++ /dev/null @@ -1,15 +0,0 @@ -// RUN: mlir-opt %s -convert-linalg-to-llvm | FileCheck %s - -func @range(%arg0: index) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %R = linalg.range %c0:%arg0:%c1 : !linalg.range - return -} -// CHECK-LABEL: func @range -// CHECK: arith.constant 0 : index -// CHECK: arith.constant 1 : index -// CHECK: llvm.mlir.undef : !llvm.struct<(i64, i64, i64)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(i64, i64, i64)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(i64, i64, i64)> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(i64, i64, i64)> diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir index cac1d4448e53..f9559cbd5b96 100644 --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -86,20 +86,10 @@ func @pad_to_static_size(%arg0: tensor, %ub0: index, %ub1: index, // ----- -func @range(%arg0: index, %arg1: index, %arg2: index) { - %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range - return -} -// CHECK-LABEL: func @range(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) { -// CHECK-NEXT: linalg.range %{{.*}} : %{{.*}} : %{{.*}} : !linalg.range - -// ----- - -func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) { +func @views(%arg0: index) { %c0 = arith.constant 0 : index %0 = arith.muli %arg0, %arg0 : index %1 = memref.alloc (%0) : memref - %2 = linalg.range %arg0:%arg1:%arg2 : !linalg.range %3 = memref.view %1[%c0][%arg0, %arg0] : memref to memref %4 = memref.view %1[%c0][%arg0, %arg0] : memref to memref> memref.dealloc %1 : memref @@ -108,7 +98,6 @@ func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index // CHECK-LABEL: func @views // CHECK: arith.muli %{{.*}}, %{{.*}} : index // CHECK-NEXT: memref.alloc(%{{.*}}) : memref -// CHECK-NEXT: range // CHECK-NEXT: memref.view %{{.*}}[%{{.*}}][%{{.*}}] : // CHECK-SAME: memref to memref // CHECK-NEXT: memref.view %{{.*}}[%{{.*}}][%{{.*}}] :