From 0d22a3fdc87cb8e96a73cb427c6621c405c4674e Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Mon, 2 Dec 2019 07:51:27 -0800 Subject: [PATCH] NFC: Update std.subview op to use AttrSizedOperandSegments This turns a few manually written helper methods into auto-generated ones. PiperOrigin-RevId: 283339617 --- mlir/include/mlir/Dialect/StandardOps/Ops.td | 50 +++---- mlir/include/mlir/IR/Builders.h | 2 + .../StandardToLLVM/ConvertStandardToLLVM.cpp | 3 +- mlir/lib/Dialect/StandardOps/Ops.cpp | 123 +++++++----------- mlir/lib/IR/Builders.cpp | 8 ++ 5 files changed, 77 insertions(+), 109 deletions(-) diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td index e2731acf47fd..70cf3bb77751 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td @@ -1248,7 +1248,7 @@ def ViewOp : Std_Op<"view", [NoSideEffect]> { let hasCanonicalizer = 1; } -def SubViewOp : Std_Op<"subview", [NoSideEffect]> { +def SubViewOp : Std_Op<"subview", [AttrSizedOperandSegments, NoSideEffect]> { let summary = "memref subview operation"; let description = [{ The "subview" operation converts a memref type to another memref type @@ -1356,23 +1356,25 @@ def SubViewOp : Std_Op<"subview", [NoSideEffect]> { // TODO(b/144779634, ravishankarm) : Use different arguments for // offsets, sizes and strides. - let arguments = (ins AnyMemRef:$source, I32Attr:$num_offsets, - I32Attr:$num_sizes, I32Attr:$num_strides, - Variadic:$operands); + let arguments = (ins + AnyMemRef:$source, + Variadic:$offsets, + Variadic:$sizes, + Variadic:$strides, + I32ElementsAttr:$operand_segment_sizes + ); let results = (outs AnyMemRef); - let builders = [OpBuilder< - "Builder *b, OperationState &result, Value *source, " - "ArrayRef offsets, ArrayRef sizes, " - "ArrayRef strides, Type resultType = Type(), " - "ArrayRef attrs = {}">, + let builders = [ OpBuilder< - "Builder *builder, OperationState &result, Type resultType, Value *source">, + "Builder *b, OperationState &result, Value *source, " + "ArrayRef offsets, ArrayRef sizes, " + "ArrayRef strides, Type resultType = Type(), " + "ArrayRef attrs = {}">, OpBuilder< - "Builder *builder, OperationState &result, Type resultType, Value *source, " - "unsigned num_offsets, unsigned num_sizes, unsigned num_strides, " - "ArrayRef offsets, ArrayRef sizes, " - "ArrayRef strides">]; + "Builder *builder, OperationState &result, " + "Type resultType, Value *source"> + ]; let extraClassDeclaration = [{ /// Returns the type of the base memref operand. @@ -1384,28 +1386,16 @@ def SubViewOp : Std_Op<"subview", [NoSideEffect]> { MemRefType getType() { return getResult()->getType().cast(); } /// Returns as integer value the number of offset operands. - int64_t getNumOffsets() { - return num_offsets().getSExtValue(); - } + int64_t getNumOffsets() { return llvm::size(offsets()); } /// Returns as integer value the number of size operands. - int64_t getNumSizes() { - return num_sizes().getSExtValue(); - } + int64_t getNumSizes() { return llvm::size(sizes()); } /// Returns as integer value the number of stride operands. - int64_t getNumStrides() { - return num_strides().getSExtValue(); - } - - /// Returns the dynamic offsets for this subview operation. - operand_range getDynamicOffsets(); + int64_t getNumStrides() { return llvm::size(strides()); } /// Returns the dynamic sizes for this subview operation if specified. - operand_range getDynamicSizes(); - - /// Returns the dynamic strides for this subview operation if specified. - operand_range getDynamicStrides(); + operand_range getDynamicSizes() { return sizes(); } // Auxiliary range data structure and helper function that unpacks the // offset, size and stride operands of the SubViewOp into a list of triples. diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 01ad38cfc110..c5ed7b16b566 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -120,6 +120,8 @@ public: IntegerAttr getI32IntegerAttr(int32_t value); IntegerAttr getI64IntegerAttr(int64_t value); + DenseIntElementsAttr getI32VectorAttr(ArrayRef values); + ArrayAttr getAffineMapArrayAttr(ArrayRef values); ArrayAttr getI32ArrayAttr(ArrayRef values); ArrayAttr getI64ArrayAttr(ArrayRef values); diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index ae2b7837c406..d226766a3fc6 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -1476,7 +1476,6 @@ struct SubViewOpLowering : public LLVMLegalizationPattern { ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto viewOp = cast(op); - SubViewOpOperandAdaptor adaptor(operands); // TODO(b/144779634, ravishankarm) : After Tblgen is adapted to support // having multiple variadic operands where each operand can have different // number of entries, clean all of this up. @@ -1518,7 +1517,7 @@ struct SubViewOpLowering : public LLVMLegalizationPattern { return matchFailure(); // Create the descriptor. - MemRefDescriptor sourceMemRef(adaptor.source()); + MemRefDescriptor sourceMemRef(operands.front()); auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); // Copy the buffer pointer from the old descriptor to the new one. diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index 0bf562337a94..31431be50543 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -1370,7 +1370,7 @@ OpFoldResult DimOp::fold(ArrayRef operands) { // Fold dim to the size argument of a SubViewOp. auto memref = memrefOrTensor()->getDefiningOp(); if (auto subview = dyn_cast_or_null(memref)) { - auto sizes = subview.getDynamicSizes(); + auto sizes = subview.sizes(); if (!sizes.empty()) return *(sizes.begin() + getIndex()); } @@ -2563,35 +2563,23 @@ static Type inferSubViewResultType(MemRefType memRefType) { memRefType.getMemorySpace()); } -void mlir::SubViewOp::build(Builder *b, OperationState &result, Type resultType, - Value *source, unsigned num_offsets, - unsigned num_sizes, unsigned num_strides, - ArrayRef offsets, ArrayRef sizes, - ArrayRef strides) { - SmallVector operands; - operands.reserve(num_offsets + num_sizes + num_strides); - operands.append(offsets.begin(), offsets.end()); - operands.append(sizes.begin(), sizes.end()); - operands.append(strides.begin(), strides.end()); - build(b, result, resultType, source, b->getI32IntegerAttr(num_offsets), - b->getI32IntegerAttr(num_sizes), b->getI32IntegerAttr(num_strides), - operands); -} - void mlir::SubViewOp::build(Builder *b, OperationState &result, Value *source, ArrayRef offsets, ArrayRef sizes, ArrayRef strides, Type resultType, ArrayRef attrs) { if (!resultType) resultType = inferSubViewResultType(source->getType().cast()); - build(b, result, resultType, source, offsets.size(), sizes.size(), - strides.size(), offsets, sizes, strides); + auto segmentAttr = b->getI32VectorAttr( + {1, static_cast(offsets.size()), static_cast(sizes.size()), + static_cast(strides.size())}); + build(b, result, resultType, source, offsets, sizes, strides, segmentAttr); result.addAttributes(attrs); } void mlir::SubViewOp::build(Builder *b, OperationState &result, Type resultType, Value *source) { - build(b, result, resultType, source, 0, 0, 0, {}, {}, {}); + build(b, result, source, /*offsets=*/{}, /*sizes=*/{}, /*strides=*/{}, + resultType); } static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) { @@ -2607,12 +2595,13 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) { parser.parseOperandList(stridesInfo, OpAsmParser::Delimiter::Square)) { return failure(); } + auto builder = parser.getBuilder(); - result.addAttribute("num_offsets", - builder.getI32IntegerAttr(offsetsInfo.size())); - result.addAttribute("num_sizes", builder.getI32IntegerAttr(sizesInfo.size())); - result.addAttribute("num_strides", - builder.getI32IntegerAttr(stridesInfo.size())); + result.addAttribute( + SubViewOp::getOperandSegmentSizeAttr(), + builder.getI32VectorAttr({1, static_cast(offsetsInfo.size()), + static_cast(sizesInfo.size()), + static_cast(stridesInfo.size())})); return failure( parser.parseOptionalAttrDict(result.attributes) || @@ -2627,14 +2616,15 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) { static void print(OpAsmPrinter &p, SubViewOp op) { p << op.getOperationName() << ' ' << *op.getOperand(0) << '['; - p.printOperands(op.getDynamicOffsets()); + p.printOperands(op.offsets()); p << "]["; - p.printOperands(op.getDynamicSizes()); + p.printOperands(op.sizes()); p << "]["; - p.printOperands(op.getDynamicStrides()); + p.printOperands(op.strides()); p << ']'; - SmallVector elidedAttrs = {"num_offsets", "num_sizes", - "num_strides"}; + + SmallVector elidedAttrs = { + SubViewOp::getOperandSegmentSizeAttr()}; p.printOptionalAttrDict(op.getAttrs(), elidedAttrs); p << " : " << op.getOperand(0)->getType() << " to " << op.getType(); } @@ -2689,14 +2679,16 @@ static LogicalResult verify(SubViewOp op) { } // Verify that if the shape of the subview type is static, then sizes are not - // dynamic values, and viceversa. + // dynamic values, and vice versa. if ((subViewType.hasStaticShape() && op.getNumSizes() != 0) || (op.getNumSizes() == 0 && !subViewType.hasStaticShape())) { return op.emitError("invalid to specify dynamic sizes when subview result " "type is statically shaped and viceversa"); } + + // Verify that if dynamic sizes are specified, then the result memref type + // have full dynamic dimensions. if (op.getNumSizes() > 0) { - // Verify that non if the shape values of the result type are static. if (llvm::any_of(subViewType.getShape(), [](int64_t dim) { return dim != ShapedType::kDynamicSize; })) { @@ -2758,9 +2750,8 @@ SmallVector SubViewOp::getRanges() { unsigned rank = getType().getRank(); res.reserve(rank); for (unsigned i = 0; i < rank; ++i) - res.emplace_back(Range{*(getDynamicOffsets().begin() + i), - *(getDynamicSizes().begin() + i), - *(getDynamicStrides().begin() + i)}); + res.emplace_back(Range{*(offsets().begin() + i), *(sizes().begin() + i), + *(strides().begin() + i)}); return res; } @@ -2792,13 +2783,13 @@ public: // Follow all or nothing approach for shapes for now. If all the operands // for sizes are constants then fold it into the type of the result memref. if (subViewType.hasStaticShape() || - llvm::any_of(subViewOp.getDynamicSizes(), [](Value *operand) { + llvm::any_of(subViewOp.sizes(), [](Value *operand) { return !matchPattern(operand, m_ConstantIndex()); })) { return matchFailure(); } SmallVector staticShape(subViewOp.getNumSizes()); - for (auto size : enumerate(subViewOp.getDynamicSizes())) { + for (auto size : enumerate(subViewOp.sizes())) { auto defOp = size.value()->getDefiningOp(); assert(defOp); staticShape[size.index()] = cast(defOp).getValue(); @@ -2808,12 +2799,12 @@ public: subViewType.getMemorySpace()); auto newSubViewOp = rewriter.create( subViewOp.getLoc(), subViewOp.source(), - llvm::to_vector<4>(subViewOp.getDynamicOffsets()), ArrayRef(), - llvm::to_vector<4>(subViewOp.getDynamicStrides()), newMemRefType); + llvm::to_vector<4>(subViewOp.offsets()), ArrayRef(), + llvm::to_vector<4>(subViewOp.strides()), newMemRefType); // Insert a memref_cast for compatibility of the uses of the op. rewriter.replaceOpWithNewOp( - llvm::to_vector<4>(subViewOp.getDynamicSizes()), subViewOp, - newSubViewOp, subViewOp.getType()); + llvm::to_vector<4>(subViewOp.sizes()), subViewOp, newSubViewOp, + subViewOp.getType()); return matchSuccess(); } }; @@ -2839,14 +2830,14 @@ public: failed(getStridesAndOffset(subViewType, resultStrides, resultOffset)) || llvm::is_contained(baseStrides, MemRefType::getDynamicStrideOrOffset()) || - llvm::any_of(subViewOp.getDynamicStrides(), [](Value *stride) { + llvm::any_of(subViewOp.strides(), [](Value *stride) { return !matchPattern(stride, m_ConstantIndex()); })) { return matchFailure(); } SmallVector staticStrides(subViewOp.getNumStrides()); - for (auto stride : enumerate(subViewOp.getDynamicStrides())) { + for (auto stride : enumerate(subViewOp.strides())) { auto defOp = stride.value()->getDefiningOp(); assert(defOp); assert(baseStrides[stride.index()] > 0); @@ -2858,15 +2849,15 @@ public: MemRefType newMemRefType = MemRefType::get(subViewType.getShape(), subViewType.getElementType(), layoutMap, subViewType.getMemorySpace()); - auto newSubViewOp = rewriter.create( - subViewOp.getLoc(), subViewOp.source(), - llvm::to_vector<4>(subViewOp.getDynamicOffsets()), - llvm::to_vector<4>(subViewOp.getDynamicSizes()), ArrayRef(), - newMemRefType); + auto newSubViewOp = + rewriter.create(subViewOp.getLoc(), subViewOp.source(), + llvm::to_vector<4>(subViewOp.offsets()), + llvm::to_vector<4>(subViewOp.sizes()), + ArrayRef(), newMemRefType); // Insert a memref_cast for compatibility of the uses of the op. rewriter.replaceOpWithNewOp( - llvm::to_vector<4>(subViewOp.getDynamicStrides()), subViewOp, - newSubViewOp, subViewOp.getType()); + llvm::to_vector<4>(subViewOp.strides()), subViewOp, newSubViewOp, + subViewOp.getType()); return matchSuccess(); } }; @@ -2893,14 +2884,14 @@ public: llvm::is_contained(baseStrides, MemRefType::getDynamicStrideOrOffset()) || baseOffset == MemRefType::getDynamicStrideOrOffset() || - llvm::any_of(subViewOp.getDynamicOffsets(), [](Value *stride) { + llvm::any_of(subViewOp.offsets(), [](Value *stride) { return !matchPattern(stride, m_ConstantIndex()); })) { return matchFailure(); } auto staticOffset = baseOffset; - for (auto offset : enumerate(subViewOp.getDynamicOffsets())) { + for (auto offset : enumerate(subViewOp.offsets())) { auto defOp = offset.value()->getDefiningOp(); assert(defOp); assert(baseStrides[offset.index()] > 0); @@ -2915,39 +2906,17 @@ public: layoutMap, subViewType.getMemorySpace()); auto newSubViewOp = rewriter.create( subViewOp.getLoc(), subViewOp.source(), ArrayRef(), - llvm::to_vector<4>(subViewOp.getDynamicSizes()), - llvm::to_vector<4>(subViewOp.getDynamicStrides()), newMemRefType); + llvm::to_vector<4>(subViewOp.sizes()), + llvm::to_vector<4>(subViewOp.strides()), newMemRefType); // Insert a memref_cast for compatibility of the uses of the op. rewriter.replaceOpWithNewOp( - llvm::to_vector<4>(subViewOp.getDynamicOffsets()), subViewOp, - newSubViewOp, subViewOp.getType()); + llvm::to_vector<4>(subViewOp.offsets()), subViewOp, newSubViewOp, + subViewOp.getType()); return matchSuccess(); } }; } // end anonymous namespace -SubViewOp::operand_range SubViewOp::getDynamicOffsets() { - auto numOffsets = getNumOffsets(); - assert(getNumOperands() >= numOffsets + 1); - return {operand_begin() + 1, operand_begin() + 1 + numOffsets}; -} - -SubViewOp::operand_range SubViewOp::getDynamicSizes() { - auto numSizes = getNumSizes(); - auto numOffsets = getNumOffsets(); - assert(getNumOperands() >= numSizes + numOffsets + 1); - return {operand_begin() + 1 + numOffsets, - operand_begin() + 1 + numOffsets + numSizes}; -} - -SubViewOp::operand_range SubViewOp::getDynamicStrides() { - auto numSizes = getNumSizes(); - auto numOffsets = getNumOffsets(); - auto numStrides = getNumStrides(); - assert(getNumOperands() >= numSizes + numOffsets + numStrides + 1); - return {operand_begin() + (1 + numOffsets + numSizes), - operand_begin() + (1 + numOffsets + numSizes + numStrides)}; -} void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index afdeefd023c2..4d6cd3550cac 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -100,6 +100,14 @@ IntegerAttr Builder::getI64IntegerAttr(int64_t value) { return IntegerAttr::get(getIntegerType(64), APInt(64, value)); } +DenseIntElementsAttr Builder::getI32VectorAttr(ArrayRef values) { + return DenseElementsAttr::get( + VectorType::get(static_cast(values.size()), + getIntegerType(32)), + values) + .cast(); +} + IntegerAttr Builder::getI32IntegerAttr(int32_t value) { return IntegerAttr::get(getIntegerType(32), APInt(32, value)); }