From 2d70eff80229be18e5f688b04164f96f52b52714 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Sun, 26 Jun 2022 19:12:38 -0700 Subject: [PATCH] [mlir] Flip more uses to prefixed accessor form (NFC). Try to keep the final flip small. Need to flip MemRef as there are many templated cases with it and Tensor. --- mlir/include/mlir/Dialect/AMX/AMX.td | 22 +- .../mlir/Dialect/Tensor/IR/TensorOps.td | 18 +- .../AffineToStandard/AffineToStandard.cpp | 29 +-- .../TosaToLinalg/TosaToLinalgNamed.cpp | 2 +- mlir/lib/Dialect/SCF/IR/SCF.cpp | 7 +- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 232 +++++++++--------- .../Tensor/Transforms/SplitPadding.cpp | 2 +- 7 files changed, 160 insertions(+), 152 deletions(-) diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td index a1df2eea4c42..f86281213c89 100644 --- a/mlir/include/mlir/Dialect/AMX/AMX.td +++ b/mlir/include/mlir/Dialect/AMX/AMX.td @@ -97,7 +97,7 @@ def TileZeroOp : AMX_Op<"tile_zero", [NoSideEffect]> { VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res); let extraClassDeclaration = [{ VectorType getVectorType() { - return res().getType().cast(); + return getRes().getType().cast(); } }]; let assemblyFormat = "attr-dict `:` type($res)"; @@ -128,10 +128,10 @@ def TileLoadOp : AMX_Op<"tile_load", [NoSideEffect]> { VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res); let extraClassDeclaration = [{ MemRefType getMemRefType() { - return base().getType().cast(); + return getBase().getType().cast(); } VectorType getVectorType() { - return res().getType().cast(); + return getRes().getType().cast(); } }]; let assemblyFormat = "$base `[` $indices `]` attr-dict `:` " @@ -158,10 +158,10 @@ def TileStoreOp : AMX_Op<"tile_store"> { VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$val); let extraClassDeclaration = [{ MemRefType getMemRefType() { - return base().getType().cast(); + return getBase().getType().cast(); } VectorType getVectorType() { - return val().getType().cast(); + return getVal().getType().cast(); } }]; let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` " @@ -194,13 +194,13 @@ def TileMulFOp : AMX_Op<"tile_mulf", [NoSideEffect, AllTypesMatch<["acc", "res"] let results = (outs VectorOfRankAndType<[2], [F32, BF16]>:$res); let extraClassDeclaration = [{ VectorType getLhsVectorType() { - return lhs().getType().cast(); + return getLhs().getType().cast(); } VectorType getRhsVectorType() { - return rhs().getType().cast(); + return getRhs().getType().cast(); } VectorType getVectorType() { - return res().getType().cast(); + return getRes().getType().cast(); } }]; let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` " @@ -235,13 +235,13 @@ def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"] let results = (outs VectorOfRankAndType<[2], [I32, I8]>:$res); let extraClassDeclaration = [{ VectorType getLhsVectorType() { - return lhs().getType().cast(); + return getLhs().getType().cast(); } VectorType getRhsVectorType() { - return rhs().getType().cast(); + return getRhs().getType().cast(); } VectorType getVectorType() { - return res().getType().cast(); + return getRes().getType().cast(); } }]; let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` " diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index 7af6fb447f0e..bd9ab3545ccb 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -27,7 +27,7 @@ class Tensor_OpWithOffsetSizesAndStrides { code extraBaseClassDeclaration = [{ /// Returns the dynamic sizes for this subview operation if specified. - ::mlir::Operation::operand_range getDynamicSizes() { return sizes(); } + ::mlir::Operation::operand_range getDynamicSizes() { return getSizes(); } /// Return the list of Range (i.e. offset, size, stride). Each /// Range entry contains either the dynamic value or a ConstantIndexOp @@ -266,7 +266,7 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice", let extraClassDeclaration = extraBaseClassDeclaration # [{ /// Returns the type of the base tensor operand. RankedTensorType getSourceType() { - return source().getType().cast(); + return getSource().getType().cast(); } /// The result of an extract_slice is always a tensor. @@ -559,7 +559,7 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [ let extraClassDeclaration = extraBaseClassDeclaration # [{ /// Returns the type of the base tensor operand. RankedTensorType getSourceType() { - return source().getType().cast(); + return getSource().getType().cast(); } /// The result of a insert_slice is always a tensor. @@ -685,7 +685,7 @@ class Tensor_ReassociativeReshapeOp traits = []> : SmallVector getReassociationExprs(); SmallVector getReassociationIndices() { SmallVector reassociationIndices; - for (auto attr : reassociation()) + for (auto attr : getReassociation()) reassociationIndices.push_back(llvm::to_vector<2>( llvm::map_range(attr.cast(), [&](Attribute indexAttr) { return indexAttr.cast().getInt(); @@ -693,10 +693,10 @@ class Tensor_ReassociativeReshapeOp traits = []> : return reassociationIndices; }; RankedTensorType getSrcType() { - return src().getType().cast(); + return getSrc().getType().cast(); } RankedTensorType getResultType() { - return result().getType().cast(); + return getResult().getType().cast(); } }]; @@ -930,7 +930,7 @@ def Tensor_PadOp : Tensor_Op<"pad", [AttrSizedOperandSegments, NoSideEffect, } RankedTensorType getSourceType() { - return source().getType().cast(); + return getSource().getType().cast(); } RankedTensorType getResultType() { return getResult().getType().cast(); @@ -965,10 +965,10 @@ def Tensor_PadOp : Tensor_Op<"pad", [AttrSizedOperandSegments, NoSideEffect, return res; } SmallVector getMixedLowPad() { - return getMixedPadImpl(static_low(), low()); + return getMixedPadImpl(getStaticLow(), getLow()); } SmallVector getMixedHighPad() { - return getMixedPadImpl(static_high(), high()); + return getMixedPadImpl(getStaticHigh(), getHigh()); } // Return true if low padding is guaranteed to be 0. bool hasZeroLowPad() { diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp index 1e0d1adefbc3..2c9a1eaa2bf6 100644 --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -100,7 +100,7 @@ public: LogicalResult matchAndRewrite(AffineMinOp op, PatternRewriter &rewriter) const override { Value reduced = - lowerAffineMapMin(rewriter, op.getLoc(), op.map(), op.operands()); + lowerAffineMapMin(rewriter, op.getLoc(), op.getMap(), op.operands()); if (!reduced) return failure(); @@ -116,7 +116,7 @@ public: LogicalResult matchAndRewrite(AffineMaxOp op, PatternRewriter &rewriter) const override { Value reduced = - lowerAffineMapMax(rewriter, op.getLoc(), op.map(), op.operands()); + lowerAffineMapMax(rewriter, op.getLoc(), op.getMap(), op.operands()); if (!reduced) return failure(); @@ -156,7 +156,7 @@ public: auto scfForOp = rewriter.create(loc, lowerBound, upperBound, step, op.getIterOperands()); rewriter.eraseBlock(scfForOp.getBody()); - rewriter.inlineRegionBefore(op.region(), scfForOp.getRegion(), + rewriter.inlineRegionBefore(op.getRegion(), scfForOp.getRegion(), scfForOp.getRegion().end()); rewriter.replaceOp(op, scfForOp.getResults()); return success(); @@ -193,20 +193,20 @@ public: return rewriter.notifyMatchFailure(op, "couldn't convert upper bounds"); upperBoundTuple.push_back(upper); } - steps.reserve(op.steps().size()); - for (int64_t step : op.steps()) + steps.reserve(op.getSteps().size()); + for (int64_t step : op.getSteps()) steps.push_back(rewriter.create(loc, step)); // Get the terminator op. Operation *affineParOpTerminator = op.getBody()->getTerminator(); scf::ParallelOp parOp; - if (op.results().empty()) { + if (op.getResults().empty()) { // Case with no reduction operations/return values. parOp = rewriter.create(loc, lowerBoundTuple, upperBoundTuple, steps, /*bodyBuilderFn=*/nullptr); rewriter.eraseBlock(parOp.getBody()); - rewriter.inlineRegionBefore(op.region(), parOp.getRegion(), + rewriter.inlineRegionBefore(op.getRegion(), parOp.getRegion(), parOp.getRegion().end()); rewriter.replaceOp(op, parOp.getResults()); return success(); @@ -214,7 +214,7 @@ public: // Case with affine.parallel with reduction operations/return values. // scf.parallel handles the reduction operation differently unlike // affine.parallel. - ArrayRef reductions = op.reductions().getValue(); + ArrayRef reductions = op.getReductions().getValue(); for (auto pair : llvm::zip(reductions, op.getResultTypes())) { // For each of the reduction operations get the identity values for // initialization of the result values. @@ -234,7 +234,7 @@ public: // Copy the body of the affine.parallel op. rewriter.eraseBlock(parOp.getBody()); - rewriter.inlineRegionBefore(op.region(), parOp.getRegion(), + rewriter.inlineRegionBefore(op.getRegion(), parOp.getRegion(), parOp.getRegion().end()); assert(reductions.size() == affineParOpTerminator->getNumOperands() && "Unequal number of reductions and operands."); @@ -299,13 +299,14 @@ public: : rewriter.create(loc, /*value=*/1, /*width=*/1); - bool hasElseRegion = !op.elseRegion().empty(); + bool hasElseRegion = !op.getElseRegion().empty(); auto ifOp = rewriter.create(loc, op.getResultTypes(), cond, hasElseRegion); - rewriter.inlineRegionBefore(op.thenRegion(), &ifOp.getThenRegion().back()); + rewriter.inlineRegionBefore(op.getThenRegion(), + &ifOp.getThenRegion().back()); rewriter.eraseBlock(&ifOp.getThenRegion().back()); if (hasElseRegion) { - rewriter.inlineRegionBefore(op.elseRegion(), + rewriter.inlineRegionBefore(op.getElseRegion(), &ifOp.getElseRegion().back()); rewriter.eraseBlock(&ifOp.getElseRegion().back()); } @@ -375,8 +376,8 @@ public: // Build memref.prefetch memref[expandedMap.results]. rewriter.replaceOpWithNewOp( - op, op.memref(), *resultOperands, op.isWrite(), op.localityHint(), - op.isDataCache()); + op, op.getMemref(), *resultOperands, op.getIsWrite(), + op.getLocalityHint(), op.getIsDataCache()); return success(); } }; diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp index 866fea818e8f..6af4c02b38ae 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -61,7 +61,7 @@ static mlir::Value applyPad(Location loc, Value input, ArrayRef pad, return tensor::createPadScalarOp(RankedTensorType::get(paddedShape, inputETy), input, padValue, lowIndices, highIndices, /*nofold=*/false, loc, rewriter) - .result(); + .getResult(); } static mlir::Value reifyConstantDim(Attribute attr, diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 293238e82b16..73eca29ec927 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -889,12 +889,12 @@ struct ForOpTensorCastFolder : public OpRewritePattern { // Must be a tensor.cast op pair with matching types. if (outgoingCastOp.getResult().getType() != - incomingCast.source().getType()) + incomingCast.getSource().getType()) continue; // Create a new ForOp with that iter operand replaced. auto newForOp = replaceTensorCastForOpIterArg(rewriter, iterOpOperand, - incomingCast.source()); + incomingCast.getSource()); // Insert outgoing cast and use it to replace the corresponding result. rewriter.setInsertionPointAfter(newForOp); @@ -902,7 +902,8 @@ struct ForOpTensorCastFolder : public OpRewritePattern { unsigned returnIdx = iterOpOperand.getOperandNumber() - op.getNumControlOperands(); replacements[returnIdx] = rewriter.create( - op.getLoc(), incomingCast.dest().getType(), replacements[returnIdx]); + op.getLoc(), incomingCast.getDest().getType(), + replacements[returnIdx]); rewriter.replaceOp(op, replacements); return success(); } diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index d9f2145d025c..897af9fcee6f 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -97,7 +97,7 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) { // Can fold if the source of cast has at least as much static information as // its results. return preservesStaticInformation(castOp.getType(), - castOp.source().getType()); + castOp.getSource().getType()); } /// Determines whether the tensor::CastOp casts to a more static version of the @@ -123,7 +123,7 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) { bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) { if (!castOp) return false; - return preservesStaticInformation(castOp.source().getType(), + return preservesStaticInformation(castOp.getSource().getType(), castOp.getType()); } @@ -250,13 +250,15 @@ struct TensorCastExtractSlice : public OpRewritePattern { tensorCast.getOperand().getDefiningOp(); if (!extractOperand || !canFoldIntoProducerOp(tensorCast) || - tensorCast.getType().getShape() == - tensorCast.source().getType().cast().getShape()) + tensorCast.getType().getShape() == tensorCast.getSource() + .getType() + .cast() + .getShape()) return failure(); SmallVector sizes = extractOperand.getMixedSizes(); auto dimMask = computeRankReductionMask( - extractFromI64ArrayAttr(extractOperand.static_sizes()), + extractFromI64ArrayAttr(extractOperand.getStaticSizes()), extractOperand.getType().getShape()); size_t dimIndex = 0; for (size_t i = 0, e = sizes.size(); i < e; i++) { @@ -270,7 +272,7 @@ struct TensorCastExtractSlice : public OpRewritePattern { rewriter.replaceOpWithNewOp( tensorCast, tensorCast.getType().cast(), - extractOperand.source(), extractOperand.getMixedOffsets(), sizes, + extractOperand.getSource(), extractOperand.getMixedOffsets(), sizes, extractOperand.getMixedStrides()); return success(); } @@ -295,7 +297,7 @@ void DimOp::build(OpBuilder &builder, OperationState &result, Value source, } Optional DimOp::getConstantIndex() { - if (auto constantOp = index().getDefiningOp()) + if (auto constantOp = getIndex().getDefiningOp()) return constantOp.getValue().cast().getInt(); return {}; } @@ -307,7 +309,7 @@ LogicalResult DimOp::verify() { return success(); // Check that constant index is not knowingly out of range. - auto type = source().getType(); + auto type = getSource().getType(); if (auto tensorType = type.dyn_cast()) { if (*index >= tensorType.getRank()) return emitOpError("index is out of range"); @@ -326,7 +328,7 @@ OpFoldResult DimOp::fold(ArrayRef operands) { return {}; // Folding for unranked types (UnrankedTensorType) is not supported. - auto tensorType = source().getType().dyn_cast(); + auto tensorType = getSource().getType().dyn_cast(); if (!tensorType) return {}; @@ -336,7 +338,7 @@ OpFoldResult DimOp::fold(ArrayRef operands) { return builder.getIndexAttr(tensorType.getShape()[index.getInt()]); } - Operation *definingOp = source().getDefiningOp(); + Operation *definingOp = getSource().getDefiningOp(); // Fold dim to the operand of tensor.generate. if (auto fromElements = dyn_cast_or_null(definingOp)) { @@ -347,7 +349,7 @@ OpFoldResult DimOp::fold(ArrayRef operands) { assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()])); // Find the operand of the fromElements that corresponds to this index. - auto dynExtents = fromElements.dynamicExtents().begin(); + auto dynExtents = fromElements.getDynamicExtents().begin(); for (auto dim : resultType.getShape().take_front(index.getInt())) if (ShapedType::isDynamic(dim)) dynExtents++; @@ -381,11 +383,11 @@ struct DimOfCastOp : public OpRewritePattern { LogicalResult matchAndRewrite(DimOp dimOp, PatternRewriter &rewriter) const override { - auto castOp = dimOp.source().getDefiningOp(); + auto castOp = dimOp.getSource().getDefiningOp(); if (!castOp) return failure(); Value newSource = castOp.getOperand(); - rewriter.replaceOpWithNewOp(dimOp, newSource, dimOp.index()); + rewriter.replaceOpWithNewOp(dimOp, newSource, dimOp.getIndex()); return success(); } }; @@ -402,8 +404,8 @@ void DimOp::getCanonicalizationPatterns(RewritePatternSet &results, LogicalResult ExtractOp::verify() { // Verify the # indices match if we have a ranked type. - if (auto tensorType = tensor().getType().dyn_cast()) - if (tensorType.getRank() != static_cast(indices().size())) + if (auto tensorType = getTensor().getType().dyn_cast()) + if (tensorType.getRank() != static_cast(getIndices().size())) return emitOpError("incorrect number of indices for extract_element"); return success(); @@ -425,7 +427,7 @@ OpFoldResult ExtractOp::fold(ArrayRef operands) { } // Fold extract(from_elements(...)). - if (auto fromElementsOp = tensor().getDefiningOp()) { + if (auto fromElementsOp = getTensor().getDefiningOp()) { auto tensorType = fromElementsOp.getType().cast(); auto rank = tensorType.getRank(); assert(static_cast(indices.size()) == tensorType.getRank() && @@ -439,10 +441,10 @@ OpFoldResult ExtractOp::fold(ArrayRef operands) { } // Prevent out of bounds accesses. This can happen in invalid code that will // never execute. - if (static_cast(fromElementsOp.elements().size()) <= flatIndex || + if (static_cast(fromElementsOp.getElements().size()) <= flatIndex || flatIndex < 0) return {}; - return fromElementsOp.elements()[flatIndex]; + return fromElementsOp.getElements()[flatIndex]; } // If this is an elements attribute, query the value at the given indices. @@ -503,14 +505,14 @@ struct ExtractElementFromIndexCast LogicalResult matchAndRewrite(tensor::ExtractOp extract, PatternRewriter &rewriter) const final { Location loc = extract.getLoc(); - auto indexCast = extract.tensor().getDefiningOp(); + auto indexCast = extract.getTensor().getDefiningOp(); if (!indexCast) return failure(); Type elementTy = getElementTypeOrSelf(indexCast.getIn()); auto newExtract = rewriter.create( - loc, elementTy, indexCast.getIn(), extract.indices()); + loc, elementTy, indexCast.getIn(), extract.getIndices()); rewriter.replaceOpWithNewOp(extract, extract.getType(), newExtract); @@ -532,8 +534,8 @@ void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results, LogicalResult InsertOp::verify() { // Verify the # indices match if we have a ranked type. - if (auto destType = dest().getType().dyn_cast()) - if (destType.getRank() != static_cast(indices().size())) + if (auto destType = getDest().getType().dyn_cast()) + if (destType.getRank() != static_cast(getIndices().size())) return emitOpError("incorrect number of indices"); return success(); } @@ -581,16 +583,16 @@ LogicalResult GenerateOp::verify() { LogicalResult GenerateOp::verifyRegions() { RankedTensorType resultTy = getType().cast(); // Ensure that region arguments span the index space. - if (!llvm::all_of(body().getArgumentTypes(), + if (!llvm::all_of(getBody().getArgumentTypes(), [](Type ty) { return ty.isIndex(); })) return emitError("all body arguments must be index"); - if (body().getNumArguments() != resultTy.getRank()) + if (getBody().getNumArguments() != resultTy.getRank()) return emitError("must have one body argument per input dimension"); // Ensure that the region yields an element of the right type. - auto yieldOp = cast(body().getBlocks().front().getTerminator()); + auto yieldOp = cast(getBody().getBlocks().front().getTerminator()); - if (yieldOp.value().getType() != resultTy.getElementType()) + if (yieldOp.getValue().getType() != resultTy.getElementType()) return emitOpError( "body must be terminated with a `yield` operation of the tensor " "element type"); @@ -634,7 +636,7 @@ struct StaticTensorGenerate : public OpRewritePattern { SmallVector newOperands; SmallVector newShape; - auto operandsIt = tensorFromElements.dynamicExtents().begin(); + auto operandsIt = tensorFromElements.getDynamicExtents().begin(); for (int64_t dim : resultType.getShape()) { if (!ShapedType::isDynamic(dim)) { @@ -651,15 +653,15 @@ struct StaticTensorGenerate : public OpRewritePattern { operandsIt++; } - if (newOperands.size() == tensorFromElements.dynamicExtents().size()) + if (newOperands.size() == tensorFromElements.getDynamicExtents().size()) return failure(); auto loc = tensorFromElements.getLoc(); auto newOp = rewriter.create( loc, RankedTensorType::get(newShape, resultType.getElementType()), newOperands); - rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(), - newOp.body().begin()); + rewriter.inlineRegionBefore(tensorFromElements.getBody(), newOp.getBody(), + newOp.getBody().begin()); rewriter.replaceOpWithNewOp(tensorFromElements, resultType, newOp); return success(); @@ -682,19 +684,19 @@ struct ExtractFromTensorGenerate : public OpRewritePattern { LogicalResult matchAndRewrite(tensor::ExtractOp extract, PatternRewriter &rewriter) const final { - auto tensorFromElements = extract.tensor().getDefiningOp(); + auto tensorFromElements = extract.getTensor().getDefiningOp(); if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements)) return failure(); BlockAndValueMapping mapping; Block *body = &tensorFromElements.getBody().front(); - mapping.map(body->getArguments(), extract.indices()); + mapping.map(body->getArguments(), extract.getIndices()); for (auto &op : body->without_terminator()) rewriter.clone(op, mapping); auto yield = cast(body->getTerminator()); - rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value())); + rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue())); return success(); } }; @@ -712,12 +714,12 @@ struct ExtractFromTensorCast : public OpRewritePattern { LogicalResult matchAndRewrite(tensor::ExtractOp extract, PatternRewriter &rewriter) const final { - auto tensorCast = extract.tensor().getDefiningOp(); + auto tensorCast = extract.getTensor().getDefiningOp(); if (!tensorCast) return failure(); - rewriter.replaceOpWithNewOp(extract, tensorCast.source(), - extract.indices()); + rewriter.replaceOpWithNewOp( + extract, tensorCast.getSource(), extract.getIndices()); return success(); } }; @@ -756,14 +758,15 @@ static int64_t getNumElements(ShapedType type) { } LogicalResult ReshapeOp::verify() { - TensorType operandType = source().getType().cast(); - TensorType resultType = result().getType().cast(); + TensorType operandType = getSource().getType().cast(); + TensorType resultType = getResult().getType().cast(); if (operandType.getElementType() != resultType.getElementType()) return emitOpError("element types of source and destination tensor " "types should be the same"); - int64_t shapeSize = shape().getType().cast().getDimSize(0); + int64_t shapeSize = + getShape().getType().cast().getDimSize(0); auto resultRankedType = resultType.dyn_cast(); auto operandRankedType = operandType.dyn_cast(); @@ -891,7 +894,7 @@ struct FoldReshapeWithConstant : OpRewritePattern { LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, PatternRewriter &rewriter) const override { DenseElementsAttr attr; - if (!matchPattern(reshapeOp.src(), m_Constant(&attr))) + if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr))) return failure(); if (!attr || !attr.isSplat()) return failure(); @@ -910,7 +913,7 @@ struct FoldReshapeWithFromElements : OpRewritePattern { LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, PatternRewriter &rewriter) const override { auto fromElements = - reshapeOp.src().template getDefiningOp(); + reshapeOp.getSrc().template getDefiningOp(); if (!fromElements) return failure(); @@ -920,7 +923,7 @@ struct FoldReshapeWithFromElements : OpRewritePattern { return failure(); rewriter.replaceOpWithNewOp(reshapeOp, reshapeOp.getType(), - fromElements.elements()); + fromElements.getElements()); return success(); } }; @@ -1208,7 +1211,7 @@ public: })) return failure(); - auto castOp = sliceOp.source().getDefiningOp(); + auto castOp = sliceOp.getSource().getDefiningOp(); if (!castOp) return failure(); @@ -1221,9 +1224,9 @@ public: sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(), sliceOp.getMixedStrides()); Value newSlice = rewriter.create( - sliceOp.getLoc(), resultType, castOp.source(), sliceOp.offsets(), - sliceOp.sizes(), sliceOp.strides(), sliceOp.static_offsets(), - sliceOp.static_sizes(), sliceOp.static_strides()); + sliceOp.getLoc(), resultType, castOp.getSource(), sliceOp.getOffsets(), + sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(), + sliceOp.getStaticSizes(), sliceOp.getStaticStrides()); rewriter.replaceOpWithNewOp(sliceOp, sliceOp.getType(), newSlice); return success(); @@ -1277,7 +1280,7 @@ public: LogicalResult matchAndRewrite(ExtractSliceOp op, PatternRewriter &rewriter) const override { DenseElementsAttr attr; - if (!matchPattern(op.source(), m_Constant(&attr))) + if (!matchPattern(op.getSource(), m_Constant(&attr))) return failure(); // A constant splat is handled by fold(). @@ -1285,8 +1288,8 @@ public: return failure(); // Dynamic result shape is not supported. - auto sourceType = op.source().getType().cast(); - auto resultType = op.result().getType().cast(); + auto sourceType = op.getSource().getType().cast(); + auto resultType = op.getResult().getType().cast(); if (!sourceType.hasStaticShape() || !resultType.hasStaticShape()) return failure(); @@ -1299,13 +1302,13 @@ public: return failure(); // Check if there are any dynamic parts, which are not supported. - auto offsets = extractFromI64ArrayAttr(op.static_offsets()); + auto offsets = extractFromI64ArrayAttr(op.getStaticOffsets()); if (llvm::is_contained(offsets, ShapedType::kDynamicStrideOrOffset)) return failure(); - auto sizes = extractFromI64ArrayAttr(op.static_sizes()); + auto sizes = extractFromI64ArrayAttr(op.getStaticSizes()); if (llvm::is_contained(sizes, ShapedType::kDynamicSize)) return failure(); - auto strides = extractFromI64ArrayAttr(op.static_strides()); + auto strides = extractFromI64ArrayAttr(op.getStaticStrides()); if (llvm::is_contained(strides, ShapedType::kDynamicStrideOrOffset)) return failure(); @@ -1414,25 +1417,25 @@ foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, // TODO: This only checks the immediate producer; extend to go up the // insert/extract chain if the slices are disjoint. static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) { - auto insertOp = extractOp.source().getDefiningOp(); + auto insertOp = extractOp.getSource().getDefiningOp(); auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; }; - if (insertOp && insertOp.source().getType() == extractOp.getType() && + if (insertOp && insertOp.getSource().getType() == extractOp.getType() && insertOp.isSameAs(extractOp, isSame)) - return insertOp.source(); + return insertOp.getSource(); return {}; } OpFoldResult ExtractSliceOp::fold(ArrayRef operands) { if (auto splat = operands[0].dyn_cast_or_null()) { - auto resultType = result().getType().cast(); + auto resultType = getResult().getType().cast(); if (resultType.hasStaticShape()) return splat.resizeSplat(resultType); } if (getSourceType() == getType() && succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType()))) - return this->source(); + return this->getSource(); if (Value slice = foldExtractAfterInsertSlice(*this)) return slice; @@ -1518,8 +1521,8 @@ verifyInsertSliceOp(ShapedType srcType, ShapedType dstType, LogicalResult InsertSliceOp::verify() { ShapedType expectedType; auto result = - verifyInsertSliceOp(getSourceType(), getType(), static_offsets(), - static_sizes(), static_strides(), &expectedType); + verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(), + getStaticSizes(), getStaticStrides(), &expectedType); return produceSliceErrorMsg(result, *this, expectedType); } @@ -1539,15 +1542,15 @@ LogicalResult InsertSliceOp::verify() { /// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1] /// ``` static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) { - auto prevInsertOp = insertOp.dest().getDefiningOp(); + auto prevInsertOp = insertOp.getDest().getDefiningOp(); auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; }; if (!prevInsertOp || - prevInsertOp.source().getType() != insertOp.source().getType() || + prevInsertOp.getSource().getType() != insertOp.getSource().getType() || !prevInsertOp.isSameAs(insertOp, isSame)) return failure(); - insertOp.destMutable().assign(prevInsertOp.dest()); + insertOp.getDestMutable().assign(prevInsertOp.getDest()); return success(); } @@ -1555,7 +1558,7 @@ OpFoldResult InsertSliceOp::fold(ArrayRef) { if (getSourceType().hasStaticShape() && getType().hasStaticShape() && getSourceType() == getType() && succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType()))) - return this->source(); + return this->getSource(); if (succeeded(foldInsertAfterInsertSlice(*this))) return getResult(); return OpFoldResult(); @@ -1566,7 +1569,7 @@ LogicalResult InsertSliceOp::reifyResultShapes( reifiedReturnShapes.resize(1, SmallVector(getType().getRank())); for (auto dim : llvm::seq(0, getType().getRank())) { reifiedReturnShapes[0][dim] = - builder.createOrFold(getLoc(), dest(), dim); + builder.createOrFold(getLoc(), getDest(), dim); } return success(); } @@ -1600,13 +1603,13 @@ public: auto sourceType = ExtractSliceOp::inferRankReducedResultType( insertSliceOp.getSourceType().getRank(), insertSliceOp.getType(), mixedOffsets, mixedSizes, mixedStrides); - Value toInsert = insertSliceOp.source(); + Value toInsert = insertSliceOp.getSource(); if (sourceType != insertSliceOp.getSourceType()) toInsert = rewriter.create(insertSliceOp.getLoc(), sourceType, toInsert); rewriter.replaceOpWithNewOp( - insertSliceOp, toInsert, insertSliceOp.dest(), mixedOffsets, mixedSizes, - mixedStrides); + insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets, + mixedSizes, mixedStrides); return success(); } }; @@ -1643,22 +1646,23 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern { auto castOp = v.getDefiningOp(); if (!castOp || !canFoldIntoConsumerOp(castOp)) return llvm::None; - return castOp.source(); + return castOp.getSource(); }; Optional sourceCastSource = - getSourceOfCastOp(insertSliceOp.source()); - Optional destCastSource = getSourceOfCastOp(insertSliceOp.dest()); + getSourceOfCastOp(insertSliceOp.getSource()); + Optional destCastSource = getSourceOfCastOp(insertSliceOp.getDest()); if (!sourceCastSource && !destCastSource) return failure(); - auto src = (sourceCastSource ? *sourceCastSource : insertSliceOp.source()); - auto dst = (destCastSource ? *destCastSource : insertSliceOp.dest()); + auto src = + (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource()); + auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest()); auto srcType = src.getType().cast(); auto dstType = dst.getType().cast(); - if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.static_offsets(), - insertSliceOp.static_sizes(), - insertSliceOp.static_strides()) != + if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(), + insertSliceOp.getStaticSizes(), + insertSliceOp.getStaticStrides()) != SliceVerificationResult::Success) return failure(); @@ -1724,9 +1728,9 @@ struct InsertSliceOpSourceCastInserter final // 3) Cast-compatible with srcType. // Insert the cast. Value cast = rewriter.create( - insertSliceOp.getLoc(), newSrcType, insertSliceOp.source()); + insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource()); rewriter.replaceOpWithNewOp( - insertSliceOp, cast, insertSliceOp.dest(), + insertSliceOp, cast, insertSliceOp.getDest(), insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides()); return success(); @@ -1781,11 +1785,11 @@ ParseResult parseInferType(OpAsmParser &parser, } LogicalResult PadOp::verify() { - auto sourceType = source().getType().cast(); - auto resultType = result().getType().cast(); - auto expectedType = - PadOp::inferResultType(sourceType, extractFromI64ArrayAttr(static_low()), - extractFromI64ArrayAttr(static_high())); + auto sourceType = getSource().getType().cast(); + auto resultType = getResult().getType().cast(); + auto expectedType = PadOp::inferResultType( + sourceType, extractFromI64ArrayAttr(getStaticLow()), + extractFromI64ArrayAttr(getStaticHigh())); for (int i = 0, e = sourceType.getRank(); i < e; ++i) { if (resultType.getDimSize(i) == expectedType.getDimSize(i)) continue; @@ -1801,7 +1805,7 @@ LogicalResult PadOp::verify() { LogicalResult PadOp::verifyRegions() { auto ®ion = getRegion(); - unsigned rank = result().getType().cast().getRank(); + unsigned rank = getResult().getType().cast().getRank(); Block &block = region.front(); if (block.getNumArguments() != rank) return emitError("expected the block to have ") << rank << " arguments"; @@ -1815,7 +1819,7 @@ LogicalResult PadOp::verifyRegions() { // Ensure that the region yields an element of the right type. auto yieldOp = llvm::cast(block.getTerminator()); - if (yieldOp.value().getType() != + if (yieldOp.getValue().getType() != getType().cast().getElementType()) return emitOpError("expected yield type to match shape element type"); @@ -1919,10 +1923,11 @@ struct FoldStaticZeroPadding : public OpRewritePattern { PatternRewriter &rewriter) const override { if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad()) return failure(); - if (padTensorOp.nofold()) + if (padTensorOp.getNofold()) return failure(); rewriter.replaceOpWithNewOp( - padTensorOp, padTensorOp.result().getType(), padTensorOp.source()); + padTensorOp, padTensorOp.getResult().getType(), + padTensorOp.getSource()); return success(); } }; @@ -1933,25 +1938,26 @@ struct FoldSourceTensorCast : public OpRewritePattern { LogicalResult matchAndRewrite(PadOp padTensorOp, PatternRewriter &rewriter) const override { - auto castOp = padTensorOp.source().getDefiningOp(); + auto castOp = padTensorOp.getSource().getDefiningOp(); if (!tensor::canFoldIntoConsumerOp(castOp)) return failure(); auto newResultType = PadOp::inferResultType( - castOp.source().getType().cast(), - extractFromI64ArrayAttr(padTensorOp.static_low()), - extractFromI64ArrayAttr(padTensorOp.static_high()), + castOp.getSource().getType().cast(), + extractFromI64ArrayAttr(padTensorOp.getStaticLow()), + extractFromI64ArrayAttr(padTensorOp.getStaticHigh()), padTensorOp.getResultType().getShape()); if (newResultType == padTensorOp.getResultType()) { rewriter.updateRootInPlace(padTensorOp, [&]() { - padTensorOp.sourceMutable().assign(castOp.source()); + padTensorOp.getSourceMutable().assign(castOp.getSource()); }); } else { auto newOp = rewriter.create( - padTensorOp->getLoc(), newResultType, padTensorOp.source(), - padTensorOp.low(), padTensorOp.high(), padTensorOp.static_low(), - padTensorOp.static_high(), padTensorOp.nofold()); + padTensorOp->getLoc(), newResultType, padTensorOp.getSource(), + padTensorOp.getLow(), padTensorOp.getHigh(), + padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(), + padTensorOp.getNofold()); BlockAndValueMapping mapper; padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper); @@ -1969,25 +1975,25 @@ struct FoldTargetTensorCast : public OpRewritePattern { LogicalResult matchAndRewrite(PadOp padTensorOp, PatternRewriter &rewriter) const override { - if (!padTensorOp.result().hasOneUse()) + if (!padTensorOp.getResult().hasOneUse()) return failure(); auto tensorCastOp = dyn_cast(*padTensorOp->getUsers().begin()); if (!tensorCastOp) return failure(); - if (!tensor::preservesStaticInformation(padTensorOp.result().getType(), - tensorCastOp.dest().getType())) + if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(), + tensorCastOp.getDest().getType())) return failure(); auto replacementOp = rewriter.create( - padTensorOp.getLoc(), tensorCastOp.dest().getType(), - padTensorOp.source(), padTensorOp.low(), padTensorOp.high(), - padTensorOp.static_low(), padTensorOp.static_high(), - padTensorOp.nofold()); - replacementOp.region().takeBody(padTensorOp.region()); + padTensorOp.getLoc(), tensorCastOp.getDest().getType(), + padTensorOp.getSource(), padTensorOp.getLow(), padTensorOp.getHigh(), + padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(), + padTensorOp.getNofold()); + replacementOp.getRegion().takeBody(padTensorOp.getRegion()); - rewriter.replaceOp(padTensorOp, replacementOp.result()); - rewriter.replaceOp(tensorCastOp, replacementOp.result()); + rewriter.replaceOp(padTensorOp, replacementOp.getResult()); + rewriter.replaceOp(tensorCastOp, replacementOp.getResult()); return success(); } }; @@ -2031,13 +2037,13 @@ struct FoldOrthogonalPaddings : public OpRewritePattern { LogicalResult matchAndRewrite(PadOp padOp, PatternRewriter &rewriter) const override { - auto innerSliceOp = padOp.source().getDefiningOp(); + auto innerSliceOp = padOp.getSource().getDefiningOp(); if (!innerSliceOp) return failure(); - auto outerPadOp = innerSliceOp.source().getDefiningOp(); - if (!outerPadOp || outerPadOp.nofold()) + auto outerPadOp = innerSliceOp.getSource().getDefiningOp(); + if (!outerPadOp || outerPadOp.getNofold()) return failure(); - auto outerSliceOp = outerPadOp.source().getDefiningOp(); + auto outerSliceOp = outerPadOp.getSource().getDefiningOp(); if (!outerSliceOp) return failure(); @@ -2136,11 +2142,11 @@ struct FoldOrthogonalPaddings : public OpRewritePattern { // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs the // two paddings in one step. auto newSliceOp = rewriter.create( - padOp.getLoc(), outerSliceOp.source(), newOffsets, newSizes, + padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes, innerSliceOp.getMixedStrides()); auto newPadOp = rewriter.create( padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(), - padOp.getMixedLowPad(), newHighPad, padOp.nofold()); + padOp.getMixedLowPad(), newHighPad, padOp.getNofold()); rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(), newPadOp.getRegion().begin()); rewriter.replaceOp(padOp, newPadOp.getResult()); @@ -2169,7 +2175,7 @@ Value PadOp::getConstantPaddingValue() { auto yieldOp = dyn_cast(getRegion().front().getTerminator()); if (!yieldOp) return {}; - Value padValue = yieldOp.value(); + Value padValue = yieldOp.getValue(); // Check if yield value is a constant. if (matchPattern(padValue, m_Constant())) return padValue; @@ -2182,8 +2188,8 @@ Value PadOp::getConstantPaddingValue() { OpFoldResult PadOp::fold(ArrayRef) { if (getResultType().hasStaticShape() && getResultType() == getSourceType() && - !nofold()) - return source(); + !getNofold()) + return getSource(); return {}; } diff --git a/mlir/lib/Dialect/Tensor/Transforms/SplitPadding.cpp b/mlir/lib/Dialect/Tensor/Transforms/SplitPadding.cpp index 79c92466695f..eeca328493d9 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/SplitPadding.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/SplitPadding.cpp @@ -75,7 +75,7 @@ struct SplitPadding final : public OpRewritePattern { // Build the scf.if op itself. For the "then" branch, we can elide the // padding. For the "else" branch, we retain the clone op. auto thenBuilder = [&padOp](OpBuilder &builder, Location loc) { - builder.create(loc, padOp.source()); + builder.create(loc, padOp.getSource()); }; auto elseBuilder = [&padOp](OpBuilder &builder, Location loc) { Operation *newOp = builder.clone(*padOp);