From 11d144c57642ff9a7f393bc6a4809f75007ff73f Mon Sep 17 00:00:00 2001 From: gysit Date: Mon, 28 Feb 2022 11:25:12 +0000 Subject: [PATCH] [mlir][linalg] Check the iterator types are valid. Improve the LinalgOp verification to ensure the iterator types is known. Previously, unknown iterator types have been ignored without warning, which can lead to confusing bugs. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D120649 --- mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 9 +++++++++ mlir/test/Dialect/Linalg/invalid.mlir | 13 +++++++++++++ 2 files changed, 22 insertions(+) diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp index 86ef1210c747..84e26b150fa3 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -573,6 +573,15 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) { << ") to be equal to the number of output tensors (" << linalgOp.getOutputTensorOperands().size() << ")"; + // Check all iterator types are known. + auto iteratorTypesRange = + linalgOp.iterator_types().getAsValueRange(); + for (StringRef iteratorType : iteratorTypesRange) { + if (!llvm::is_contained(getAllIteratorTypeNames(), iteratorType)) + return op->emitOpError("unexpected iterator_type (") + << iteratorType << ")"; + } + // Before checking indexing maps, we need to make sure the attributes // referenced by it are valid. if (linalgOp.hasDynamicIndexingMaps()) diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index 081df97b7a0f..f220d8ecf3a5 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -95,6 +95,19 @@ func @generic_wrong_dim_in_map(%arg0: memref<1xi32>) { // ----- +func @generic_wrong_iterator(%arg0: memref<1xi32>) { + // expected-error @+1 {{op unexpected iterator_type (random)}} + linalg.generic { + indexing_maps = [ affine_map<(i) -> (i)> ], + iterator_types = ["random"]} + outs(%arg0 : memref<1xi32>) { + ^bb(%i : i32): + linalg.yield %i : i32 + } +} + +// ----- + func @generic_one_d_view(%arg0: memref(off + i)>>) { // expected-error @+1 {{expected operand rank (1) to match the result rank of indexing_map #0 (2)}} linalg.generic {