From 789c88e80e878ed866a2d8cfe29c7fd36082274c Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache
Date: Mon, 22 Nov 2021 08:52:40 +0000
Subject: [PATCH] [mlir] Fix unintentional mutation by
 VectorType/RankedTensorType::Builder dropDim

Differential Revision: https://reviews.llvm.org/D113933
---
 mlir/include/mlir/IR/BuiltinTypes.h           | 38 ++++++++++++-------
 .../Dialect/Linalg/Transforms/Transforms.cpp  | 18 ++++++---
 mlir/lib/Dialect/Vector/VectorTransforms.cpp  | 11 +++---
 3 files changed, 42 insertions(+), 25 deletions(-)

diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index b4ce23a72915..f3d2c24073dc 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -283,12 +283,14 @@ public:
     return *this;
   }
 
-  /// Create a new RankedTensor by erasing a dim from shape @pos.
-  RankedTensorType dropDim(unsigned pos) {
+  /// Erase a dim from shape @pos.
+  Builder &dropDim(unsigned pos) {
     assert(pos < shape.size() && "overflow");
-    SmallVector<int64_t> newShape(shape.begin(), shape.end());
-    newShape.erase(newShape.begin() + pos);
-    return setShape(newShape);
+    if (storage.empty())
+      storage.append(shape.begin(), shape.end());
+    storage.erase(storage.begin() + pos);
+    shape = {storage.data(), storage.size()};
+    return *this;
   }
 
   operator RankedTensorType() {
@@ -297,6 +299,8 @@ public:
 
 private:
   ArrayRef<int64_t> shape;
+  // Owning shape data for copy-on-write operations.
+  SmallVector<int64_t> storage;
   Type elementType;
   Attribute encoding;
 };
@@ -327,23 +331,29 @@ public:
     return *this;
   }
 
-  /// Create a new VectorType by erasing a dim from shape @pos.
+  /// Erase a dim from shape @pos.
+  Builder &dropDim(unsigned pos) {
+    assert(pos < shape.size() && "overflow");
+    if (storage.empty())
+      storage.append(shape.begin(), shape.end());
+    storage.erase(storage.begin() + pos);
+    shape = {storage.data(), storage.size()};
+    return *this;
+  }
+
   /// In the particular case where the vector has a single dimension that we
   /// drop, return the scalar element type.
   // TODO: unify once we have a VectorType that supports 0-D.
-  Type dropDim(unsigned pos) {
-    assert(pos < shape.size() && "overflow");
-    if (shape.size() == 1)
+  operator Type() {
+    if (shape.empty())
       return elementType;
-    SmallVector<int64_t> newShape(shape.begin(), shape.end());
-    newShape.erase(newShape.begin() + pos);
-    return setShape(newShape);
+    return VectorType::get(shape, elementType);
   }
 
-  operator VectorType() { return VectorType::get(shape, elementType); }
-
 private:
   ArrayRef<int64_t> shape;
+  // Owning shape data for copy-on-write operations.
+  SmallVector<int64_t> storage;
   Type elementType;
 };
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 657f2b760558..36bb0171823f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -876,9 +876,12 @@ struct DownscaleSizeOneWindowed2DConvolution final
     // Get new shapes and types for all operands by removing the size-1
     // dimension.
     using RTTBuilder = RankedTensorType::Builder;
-    auto newInputType = RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
-    auto newFilterType = RTTBuilder(filterType).dropDim((removeH ? 0 : 1));
-    auto newOutputType = RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
+    RankedTensorType newInputType =
+        RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
+    RankedTensorType newFilterType =
+        RTTBuilder(filterType).dropDim((removeH ? 0 : 1));
+    RankedTensorType newOutputType =
+        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
 
     // Rank-reduce operands.
     Location loc = convOp.getLoc();
@@ -948,9 +951,12 @@ struct DownscaleDepthwiseConv2DNhwcHwcOp final
     // Get new shapes and types for all operands by removing the size-1
     // dimension.
     using RTTBuilder = RankedTensorType::Builder;
-    auto newInputType = RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
-    auto newKernelType = RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
-    auto newOutputType = RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
+    RankedTensorType newInputType =
+        RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
+    RankedTensorType newKernelType =
+        RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
+    RankedTensorType newOutputType =
+        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
 
     // Rank-reduce operands.
     Location loc = convOp.getLoc();
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 37f3c31e6a48..5760e80bfcaf 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -94,15 +94,16 @@ static Value reshapeLoad(Location loc, Value val, VectorType type,
   }
   // Unroll leading dimensions.
   VectorType vType = lowType.cast<VectorType>();
-  auto resType = VectorType::Builder(type).dropDim(index).cast<VectorType>();
+  Type resType = VectorType::Builder(type).dropDim(index);
+  auto resVectorType = resType.cast<VectorType>();
   Value result = rewriter.create<arith::ConstantOp>(
-      loc, resType, rewriter.getZeroAttr(resType));
-  for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
+      loc, resVectorType, rewriter.getZeroAttr(resVectorType));
+  for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
     auto posAttr = rewriter.getI64ArrayAttr(d);
     Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
     Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
-    result =
-        rewriter.create<vector::InsertOp>(loc, resType, load, result, posAttr);
+    result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
+                                               posAttr);
   }
   return result;
 }