forked from OSchip/llvm-project
[mlir] Flip more uses to prefixed accessor form (NFC).
Try to keep the final flip small. Need to flip MemRef as there are many templated cases with it and Tensor.
This commit is contained in:
parent
ca05cc2064
commit
2d70eff802
|
@ -97,7 +97,7 @@ def TileZeroOp : AMX_Op<"tile_zero", [NoSideEffect]> {
|
|||
VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
|
||||
let extraClassDeclaration = [{
|
||||
VectorType getVectorType() {
|
||||
return res().getType().cast<VectorType>();
|
||||
return getRes().getType().cast<VectorType>();
|
||||
}
|
||||
}];
|
||||
let assemblyFormat = "attr-dict `:` type($res)";
|
||||
|
@ -128,10 +128,10 @@ def TileLoadOp : AMX_Op<"tile_load", [NoSideEffect]> {
|
|||
VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
|
||||
let extraClassDeclaration = [{
|
||||
MemRefType getMemRefType() {
|
||||
return base().getType().cast<MemRefType>();
|
||||
return getBase().getType().cast<MemRefType>();
|
||||
}
|
||||
VectorType getVectorType() {
|
||||
return res().getType().cast<VectorType>();
|
||||
return getRes().getType().cast<VectorType>();
|
||||
}
|
||||
}];
|
||||
let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
|
||||
|
@ -158,10 +158,10 @@ def TileStoreOp : AMX_Op<"tile_store"> {
|
|||
VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$val);
|
||||
let extraClassDeclaration = [{
|
||||
MemRefType getMemRefType() {
|
||||
return base().getType().cast<MemRefType>();
|
||||
return getBase().getType().cast<MemRefType>();
|
||||
}
|
||||
VectorType getVectorType() {
|
||||
return val().getType().cast<VectorType>();
|
||||
return getVal().getType().cast<VectorType>();
|
||||
}
|
||||
}];
|
||||
let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
|
||||
|
@ -194,13 +194,13 @@ def TileMulFOp : AMX_Op<"tile_mulf", [NoSideEffect, AllTypesMatch<["acc", "res"]
|
|||
let results = (outs VectorOfRankAndType<[2], [F32, BF16]>:$res);
|
||||
let extraClassDeclaration = [{
|
||||
VectorType getLhsVectorType() {
|
||||
return lhs().getType().cast<VectorType>();
|
||||
return getLhs().getType().cast<VectorType>();
|
||||
}
|
||||
VectorType getRhsVectorType() {
|
||||
return rhs().getType().cast<VectorType>();
|
||||
return getRhs().getType().cast<VectorType>();
|
||||
}
|
||||
VectorType getVectorType() {
|
||||
return res().getType().cast<VectorType>();
|
||||
return getRes().getType().cast<VectorType>();
|
||||
}
|
||||
}];
|
||||
let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
|
||||
|
@ -235,13 +235,13 @@ def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"]
|
|||
let results = (outs VectorOfRankAndType<[2], [I32, I8]>:$res);
|
||||
let extraClassDeclaration = [{
|
||||
VectorType getLhsVectorType() {
|
||||
return lhs().getType().cast<VectorType>();
|
||||
return getLhs().getType().cast<VectorType>();
|
||||
}
|
||||
VectorType getRhsVectorType() {
|
||||
return rhs().getType().cast<VectorType>();
|
||||
return getRhs().getType().cast<VectorType>();
|
||||
}
|
||||
VectorType getVectorType() {
|
||||
return res().getType().cast<VectorType>();
|
||||
return getRes().getType().cast<VectorType>();
|
||||
}
|
||||
}];
|
||||
let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
|
||||
|
|
|
@ -27,7 +27,7 @@ class Tensor_OpWithOffsetSizesAndStrides<string mnemonic,
|
|||
: Tensor_Op<mnemonic, traits> {
|
||||
code extraBaseClassDeclaration = [{
|
||||
/// Returns the dynamic sizes for this subview operation if specified.
|
||||
::mlir::Operation::operand_range getDynamicSizes() { return sizes(); }
|
||||
::mlir::Operation::operand_range getDynamicSizes() { return getSizes(); }
|
||||
|
||||
/// Return the list of Range (i.e. offset, size, stride). Each
|
||||
/// Range entry contains either the dynamic value or a ConstantIndexOp
|
||||
|
@ -266,7 +266,7 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
|
|||
let extraClassDeclaration = extraBaseClassDeclaration # [{
|
||||
/// Returns the type of the base tensor operand.
|
||||
RankedTensorType getSourceType() {
|
||||
return source().getType().cast<RankedTensorType>();
|
||||
return getSource().getType().cast<RankedTensorType>();
|
||||
}
|
||||
|
||||
/// The result of an extract_slice is always a tensor.
|
||||
|
@ -559,7 +559,7 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
|
|||
let extraClassDeclaration = extraBaseClassDeclaration # [{
|
||||
/// Returns the type of the base tensor operand.
|
||||
RankedTensorType getSourceType() {
|
||||
return source().getType().cast<RankedTensorType>();
|
||||
return getSource().getType().cast<RankedTensorType>();
|
||||
}
|
||||
|
||||
/// The result of a insert_slice is always a tensor.
|
||||
|
@ -685,7 +685,7 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
|
|||
SmallVector<ReassociationExprs, 4> getReassociationExprs();
|
||||
SmallVector<ReassociationIndices, 4> getReassociationIndices() {
|
||||
SmallVector<ReassociationIndices, 4> reassociationIndices;
|
||||
for (auto attr : reassociation())
|
||||
for (auto attr : getReassociation())
|
||||
reassociationIndices.push_back(llvm::to_vector<2>(
|
||||
llvm::map_range(attr.cast<ArrayAttr>(), [&](Attribute indexAttr) {
|
||||
return indexAttr.cast<IntegerAttr>().getInt();
|
||||
|
@ -693,10 +693,10 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
|
|||
return reassociationIndices;
|
||||
};
|
||||
RankedTensorType getSrcType() {
|
||||
return src().getType().cast<RankedTensorType>();
|
||||
return getSrc().getType().cast<RankedTensorType>();
|
||||
}
|
||||
RankedTensorType getResultType() {
|
||||
return result().getType().cast<RankedTensorType>();
|
||||
return getResult().getType().cast<RankedTensorType>();
|
||||
}
|
||||
}];
|
||||
|
||||
|
@ -930,7 +930,7 @@ def Tensor_PadOp : Tensor_Op<"pad", [AttrSizedOperandSegments, NoSideEffect,
|
|||
}
|
||||
|
||||
RankedTensorType getSourceType() {
|
||||
return source().getType().cast<RankedTensorType>();
|
||||
return getSource().getType().cast<RankedTensorType>();
|
||||
}
|
||||
RankedTensorType getResultType() {
|
||||
return getResult().getType().cast<RankedTensorType>();
|
||||
|
@ -965,10 +965,10 @@ def Tensor_PadOp : Tensor_Op<"pad", [AttrSizedOperandSegments, NoSideEffect,
|
|||
return res;
|
||||
}
|
||||
SmallVector<OpFoldResult> getMixedLowPad() {
|
||||
return getMixedPadImpl(static_low(), low());
|
||||
return getMixedPadImpl(getStaticLow(), getLow());
|
||||
}
|
||||
SmallVector<OpFoldResult> getMixedHighPad() {
|
||||
return getMixedPadImpl(static_high(), high());
|
||||
return getMixedPadImpl(getStaticHigh(), getHigh());
|
||||
}
|
||||
// Return true if low padding is guaranteed to be 0.
|
||||
bool hasZeroLowPad() {
|
||||
|
|
|
@ -100,7 +100,7 @@ public:
|
|||
LogicalResult matchAndRewrite(AffineMinOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value reduced =
|
||||
lowerAffineMapMin(rewriter, op.getLoc(), op.map(), op.operands());
|
||||
lowerAffineMapMin(rewriter, op.getLoc(), op.getMap(), op.operands());
|
||||
if (!reduced)
|
||||
return failure();
|
||||
|
||||
|
@ -116,7 +116,7 @@ public:
|
|||
LogicalResult matchAndRewrite(AffineMaxOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value reduced =
|
||||
lowerAffineMapMax(rewriter, op.getLoc(), op.map(), op.operands());
|
||||
lowerAffineMapMax(rewriter, op.getLoc(), op.getMap(), op.operands());
|
||||
if (!reduced)
|
||||
return failure();
|
||||
|
||||
|
@ -156,7 +156,7 @@ public:
|
|||
auto scfForOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound,
|
||||
step, op.getIterOperands());
|
||||
rewriter.eraseBlock(scfForOp.getBody());
|
||||
rewriter.inlineRegionBefore(op.region(), scfForOp.getRegion(),
|
||||
rewriter.inlineRegionBefore(op.getRegion(), scfForOp.getRegion(),
|
||||
scfForOp.getRegion().end());
|
||||
rewriter.replaceOp(op, scfForOp.getResults());
|
||||
return success();
|
||||
|
@ -193,20 +193,20 @@ public:
|
|||
return rewriter.notifyMatchFailure(op, "couldn't convert upper bounds");
|
||||
upperBoundTuple.push_back(upper);
|
||||
}
|
||||
steps.reserve(op.steps().size());
|
||||
for (int64_t step : op.steps())
|
||||
steps.reserve(op.getSteps().size());
|
||||
for (int64_t step : op.getSteps())
|
||||
steps.push_back(rewriter.create<arith::ConstantIndexOp>(loc, step));
|
||||
|
||||
// Get the terminator op.
|
||||
Operation *affineParOpTerminator = op.getBody()->getTerminator();
|
||||
scf::ParallelOp parOp;
|
||||
if (op.results().empty()) {
|
||||
if (op.getResults().empty()) {
|
||||
// Case with no reduction operations/return values.
|
||||
parOp = rewriter.create<scf::ParallelOp>(loc, lowerBoundTuple,
|
||||
upperBoundTuple, steps,
|
||||
/*bodyBuilderFn=*/nullptr);
|
||||
rewriter.eraseBlock(parOp.getBody());
|
||||
rewriter.inlineRegionBefore(op.region(), parOp.getRegion(),
|
||||
rewriter.inlineRegionBefore(op.getRegion(), parOp.getRegion(),
|
||||
parOp.getRegion().end());
|
||||
rewriter.replaceOp(op, parOp.getResults());
|
||||
return success();
|
||||
|
@ -214,7 +214,7 @@ public:
|
|||
// Case with affine.parallel with reduction operations/return values.
|
||||
// scf.parallel handles the reduction operation differently unlike
|
||||
// affine.parallel.
|
||||
ArrayRef<Attribute> reductions = op.reductions().getValue();
|
||||
ArrayRef<Attribute> reductions = op.getReductions().getValue();
|
||||
for (auto pair : llvm::zip(reductions, op.getResultTypes())) {
|
||||
// For each of the reduction operations get the identity values for
|
||||
// initialization of the result values.
|
||||
|
@ -234,7 +234,7 @@ public:
|
|||
|
||||
// Copy the body of the affine.parallel op.
|
||||
rewriter.eraseBlock(parOp.getBody());
|
||||
rewriter.inlineRegionBefore(op.region(), parOp.getRegion(),
|
||||
rewriter.inlineRegionBefore(op.getRegion(), parOp.getRegion(),
|
||||
parOp.getRegion().end());
|
||||
assert(reductions.size() == affineParOpTerminator->getNumOperands() &&
|
||||
"Unequal number of reductions and operands.");
|
||||
|
@ -299,13 +299,14 @@ public:
|
|||
: rewriter.create<arith::ConstantIntOp>(loc, /*value=*/1,
|
||||
/*width=*/1);
|
||||
|
||||
bool hasElseRegion = !op.elseRegion().empty();
|
||||
bool hasElseRegion = !op.getElseRegion().empty();
|
||||
auto ifOp = rewriter.create<scf::IfOp>(loc, op.getResultTypes(), cond,
|
||||
hasElseRegion);
|
||||
rewriter.inlineRegionBefore(op.thenRegion(), &ifOp.getThenRegion().back());
|
||||
rewriter.inlineRegionBefore(op.getThenRegion(),
|
||||
&ifOp.getThenRegion().back());
|
||||
rewriter.eraseBlock(&ifOp.getThenRegion().back());
|
||||
if (hasElseRegion) {
|
||||
rewriter.inlineRegionBefore(op.elseRegion(),
|
||||
rewriter.inlineRegionBefore(op.getElseRegion(),
|
||||
&ifOp.getElseRegion().back());
|
||||
rewriter.eraseBlock(&ifOp.getElseRegion().back());
|
||||
}
|
||||
|
@ -375,8 +376,8 @@ public:
|
|||
|
||||
// Build memref.prefetch memref[expandedMap.results].
|
||||
rewriter.replaceOpWithNewOp<memref::PrefetchOp>(
|
||||
op, op.memref(), *resultOperands, op.isWrite(), op.localityHint(),
|
||||
op.isDataCache());
|
||||
op, op.getMemref(), *resultOperands, op.getIsWrite(),
|
||||
op.getLocalityHint(), op.getIsDataCache());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -61,7 +61,7 @@ static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
|
|||
return tensor::createPadScalarOp(RankedTensorType::get(paddedShape, inputETy),
|
||||
input, padValue, lowIndices, highIndices,
|
||||
/*nofold=*/false, loc, rewriter)
|
||||
.result();
|
||||
.getResult();
|
||||
}
|
||||
|
||||
static mlir::Value reifyConstantDim(Attribute attr,
|
||||
|
|
|
@ -889,12 +889,12 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
|
|||
|
||||
// Must be a tensor.cast op pair with matching types.
|
||||
if (outgoingCastOp.getResult().getType() !=
|
||||
incomingCast.source().getType())
|
||||
incomingCast.getSource().getType())
|
||||
continue;
|
||||
|
||||
// Create a new ForOp with that iter operand replaced.
|
||||
auto newForOp = replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
|
||||
incomingCast.source());
|
||||
incomingCast.getSource());
|
||||
|
||||
// Insert outgoing cast and use it to replace the corresponding result.
|
||||
rewriter.setInsertionPointAfter(newForOp);
|
||||
|
@ -902,7 +902,8 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
|
|||
unsigned returnIdx =
|
||||
iterOpOperand.getOperandNumber() - op.getNumControlOperands();
|
||||
replacements[returnIdx] = rewriter.create<tensor::CastOp>(
|
||||
op.getLoc(), incomingCast.dest().getType(), replacements[returnIdx]);
|
||||
op.getLoc(), incomingCast.getDest().getType(),
|
||||
replacements[returnIdx]);
|
||||
rewriter.replaceOp(op, replacements);
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -97,7 +97,7 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
|
|||
// Can fold if the source of cast has at least as much static information as
|
||||
// its results.
|
||||
return preservesStaticInformation(castOp.getType(),
|
||||
castOp.source().getType());
|
||||
castOp.getSource().getType());
|
||||
}
|
||||
|
||||
/// Determines whether the tensor::CastOp casts to a more static version of the
|
||||
|
@ -123,7 +123,7 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
|
|||
bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
|
||||
if (!castOp)
|
||||
return false;
|
||||
return preservesStaticInformation(castOp.source().getType(),
|
||||
return preservesStaticInformation(castOp.getSource().getType(),
|
||||
castOp.getType());
|
||||
}
|
||||
|
||||
|
@ -250,13 +250,15 @@ struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
|
|||
tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
|
||||
|
||||
if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
|
||||
tensorCast.getType().getShape() ==
|
||||
tensorCast.source().getType().cast<RankedTensorType>().getShape())
|
||||
tensorCast.getType().getShape() == tensorCast.getSource()
|
||||
.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getShape())
|
||||
return failure();
|
||||
|
||||
SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
|
||||
auto dimMask = computeRankReductionMask(
|
||||
extractFromI64ArrayAttr(extractOperand.static_sizes()),
|
||||
extractFromI64ArrayAttr(extractOperand.getStaticSizes()),
|
||||
extractOperand.getType().getShape());
|
||||
size_t dimIndex = 0;
|
||||
for (size_t i = 0, e = sizes.size(); i < e; i++) {
|
||||
|
@ -270,7 +272,7 @@ struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
|
|||
|
||||
rewriter.replaceOpWithNewOp<ExtractSliceOp>(
|
||||
tensorCast, tensorCast.getType().cast<RankedTensorType>(),
|
||||
extractOperand.source(), extractOperand.getMixedOffsets(), sizes,
|
||||
extractOperand.getSource(), extractOperand.getMixedOffsets(), sizes,
|
||||
extractOperand.getMixedStrides());
|
||||
return success();
|
||||
}
|
||||
|
@ -295,7 +297,7 @@ void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
|
|||
}
|
||||
|
||||
Optional<int64_t> DimOp::getConstantIndex() {
|
||||
if (auto constantOp = index().getDefiningOp<arith::ConstantOp>())
|
||||
if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
|
||||
return constantOp.getValue().cast<IntegerAttr>().getInt();
|
||||
return {};
|
||||
}
|
||||
|
@ -307,7 +309,7 @@ LogicalResult DimOp::verify() {
|
|||
return success();
|
||||
|
||||
// Check that constant index is not knowingly out of range.
|
||||
auto type = source().getType();
|
||||
auto type = getSource().getType();
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
|
||||
if (*index >= tensorType.getRank())
|
||||
return emitOpError("index is out of range");
|
||||
|
@ -326,7 +328,7 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
|
|||
return {};
|
||||
|
||||
// Folding for unranked types (UnrankedTensorType) is not supported.
|
||||
auto tensorType = source().getType().dyn_cast<RankedTensorType>();
|
||||
auto tensorType = getSource().getType().dyn_cast<RankedTensorType>();
|
||||
if (!tensorType)
|
||||
return {};
|
||||
|
||||
|
@ -336,7 +338,7 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
|
|||
return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
|
||||
}
|
||||
|
||||
Operation *definingOp = source().getDefiningOp();
|
||||
Operation *definingOp = getSource().getDefiningOp();
|
||||
|
||||
// Fold dim to the operand of tensor.generate.
|
||||
if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
|
||||
|
@ -347,7 +349,7 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
|
|||
assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
|
||||
|
||||
// Find the operand of the fromElements that corresponds to this index.
|
||||
auto dynExtents = fromElements.dynamicExtents().begin();
|
||||
auto dynExtents = fromElements.getDynamicExtents().begin();
|
||||
for (auto dim : resultType.getShape().take_front(index.getInt()))
|
||||
if (ShapedType::isDynamic(dim))
|
||||
dynExtents++;
|
||||
|
@ -381,11 +383,11 @@ struct DimOfCastOp : public OpRewritePattern<DimOp> {
|
|||
|
||||
LogicalResult matchAndRewrite(DimOp dimOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto castOp = dimOp.source().getDefiningOp<CastOp>();
|
||||
auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
|
||||
if (!castOp)
|
||||
return failure();
|
||||
Value newSource = castOp.getOperand();
|
||||
rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index());
|
||||
rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -402,8 +404,8 @@ void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|||
|
||||
LogicalResult ExtractOp::verify() {
|
||||
// Verify the # indices match if we have a ranked type.
|
||||
if (auto tensorType = tensor().getType().dyn_cast<RankedTensorType>())
|
||||
if (tensorType.getRank() != static_cast<int64_t>(indices().size()))
|
||||
if (auto tensorType = getTensor().getType().dyn_cast<RankedTensorType>())
|
||||
if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
|
||||
return emitOpError("incorrect number of indices for extract_element");
|
||||
|
||||
return success();
|
||||
|
@ -425,7 +427,7 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
|
|||
}
|
||||
|
||||
// Fold extract(from_elements(...)).
|
||||
if (auto fromElementsOp = tensor().getDefiningOp<FromElementsOp>()) {
|
||||
if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
|
||||
auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
|
||||
auto rank = tensorType.getRank();
|
||||
assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
|
||||
|
@ -439,10 +441,10 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
|
|||
}
|
||||
// Prevent out of bounds accesses. This can happen in invalid code that will
|
||||
// never execute.
|
||||
if (static_cast<int>(fromElementsOp.elements().size()) <= flatIndex ||
|
||||
if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
|
||||
flatIndex < 0)
|
||||
return {};
|
||||
return fromElementsOp.elements()[flatIndex];
|
||||
return fromElementsOp.getElements()[flatIndex];
|
||||
}
|
||||
|
||||
// If this is an elements attribute, query the value at the given indices.
|
||||
|
@ -503,14 +505,14 @@ struct ExtractElementFromIndexCast
|
|||
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
|
||||
PatternRewriter &rewriter) const final {
|
||||
Location loc = extract.getLoc();
|
||||
auto indexCast = extract.tensor().getDefiningOp<arith::IndexCastOp>();
|
||||
auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
|
||||
if (!indexCast)
|
||||
return failure();
|
||||
|
||||
Type elementTy = getElementTypeOrSelf(indexCast.getIn());
|
||||
|
||||
auto newExtract = rewriter.create<tensor::ExtractOp>(
|
||||
loc, elementTy, indexCast.getIn(), extract.indices());
|
||||
loc, elementTy, indexCast.getIn(), extract.getIndices());
|
||||
|
||||
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
|
||||
newExtract);
|
||||
|
@ -532,8 +534,8 @@ void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|||
|
||||
LogicalResult InsertOp::verify() {
|
||||
// Verify the # indices match if we have a ranked type.
|
||||
if (auto destType = dest().getType().dyn_cast<RankedTensorType>())
|
||||
if (destType.getRank() != static_cast<int64_t>(indices().size()))
|
||||
if (auto destType = getDest().getType().dyn_cast<RankedTensorType>())
|
||||
if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
|
||||
return emitOpError("incorrect number of indices");
|
||||
return success();
|
||||
}
|
||||
|
@ -581,16 +583,16 @@ LogicalResult GenerateOp::verify() {
|
|||
LogicalResult GenerateOp::verifyRegions() {
|
||||
RankedTensorType resultTy = getType().cast<RankedTensorType>();
|
||||
// Ensure that region arguments span the index space.
|
||||
if (!llvm::all_of(body().getArgumentTypes(),
|
||||
if (!llvm::all_of(getBody().getArgumentTypes(),
|
||||
[](Type ty) { return ty.isIndex(); }))
|
||||
return emitError("all body arguments must be index");
|
||||
if (body().getNumArguments() != resultTy.getRank())
|
||||
if (getBody().getNumArguments() != resultTy.getRank())
|
||||
return emitError("must have one body argument per input dimension");
|
||||
|
||||
// Ensure that the region yields an element of the right type.
|
||||
auto yieldOp = cast<YieldOp>(body().getBlocks().front().getTerminator());
|
||||
auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
|
||||
|
||||
if (yieldOp.value().getType() != resultTy.getElementType())
|
||||
if (yieldOp.getValue().getType() != resultTy.getElementType())
|
||||
return emitOpError(
|
||||
"body must be terminated with a `yield` operation of the tensor "
|
||||
"element type");
|
||||
|
@ -634,7 +636,7 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
|
|||
|
||||
SmallVector<Value, 4> newOperands;
|
||||
SmallVector<int64_t, 4> newShape;
|
||||
auto operandsIt = tensorFromElements.dynamicExtents().begin();
|
||||
auto operandsIt = tensorFromElements.getDynamicExtents().begin();
|
||||
|
||||
for (int64_t dim : resultType.getShape()) {
|
||||
if (!ShapedType::isDynamic(dim)) {
|
||||
|
@ -651,15 +653,15 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
|
|||
operandsIt++;
|
||||
}
|
||||
|
||||
if (newOperands.size() == tensorFromElements.dynamicExtents().size())
|
||||
if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
|
||||
return failure();
|
||||
|
||||
auto loc = tensorFromElements.getLoc();
|
||||
auto newOp = rewriter.create<GenerateOp>(
|
||||
loc, RankedTensorType::get(newShape, resultType.getElementType()),
|
||||
newOperands);
|
||||
rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(),
|
||||
newOp.body().begin());
|
||||
rewriter.inlineRegionBefore(tensorFromElements.getBody(), newOp.getBody(),
|
||||
newOp.getBody().begin());
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
|
||||
newOp);
|
||||
return success();
|
||||
|
@ -682,19 +684,19 @@ struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
|
|||
|
||||
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
|
||||
PatternRewriter &rewriter) const final {
|
||||
auto tensorFromElements = extract.tensor().getDefiningOp<GenerateOp>();
|
||||
auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
|
||||
if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
|
||||
return failure();
|
||||
|
||||
BlockAndValueMapping mapping;
|
||||
Block *body = &tensorFromElements.getBody().front();
|
||||
mapping.map(body->getArguments(), extract.indices());
|
||||
mapping.map(body->getArguments(), extract.getIndices());
|
||||
for (auto &op : body->without_terminator())
|
||||
rewriter.clone(op, mapping);
|
||||
|
||||
auto yield = cast<YieldOp>(body->getTerminator());
|
||||
|
||||
rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value()));
|
||||
rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -712,12 +714,12 @@ struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
|
|||
|
||||
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
|
||||
PatternRewriter &rewriter) const final {
|
||||
auto tensorCast = extract.tensor().getDefiningOp<tensor::CastOp>();
|
||||
auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
|
||||
if (!tensorCast)
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(extract, tensorCast.source(),
|
||||
extract.indices());
|
||||
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
|
||||
extract, tensorCast.getSource(), extract.getIndices());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -756,14 +758,15 @@ static int64_t getNumElements(ShapedType type) {
|
|||
}
|
||||
|
||||
LogicalResult ReshapeOp::verify() {
|
||||
TensorType operandType = source().getType().cast<TensorType>();
|
||||
TensorType resultType = result().getType().cast<TensorType>();
|
||||
TensorType operandType = getSource().getType().cast<TensorType>();
|
||||
TensorType resultType = getResult().getType().cast<TensorType>();
|
||||
|
||||
if (operandType.getElementType() != resultType.getElementType())
|
||||
return emitOpError("element types of source and destination tensor "
|
||||
"types should be the same");
|
||||
|
||||
int64_t shapeSize = shape().getType().cast<RankedTensorType>().getDimSize(0);
|
||||
int64_t shapeSize =
|
||||
getShape().getType().cast<RankedTensorType>().getDimSize(0);
|
||||
auto resultRankedType = resultType.dyn_cast<RankedTensorType>();
|
||||
auto operandRankedType = operandType.dyn_cast<RankedTensorType>();
|
||||
|
||||
|
@ -891,7 +894,7 @@ struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
|
|||
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
DenseElementsAttr attr;
|
||||
if (!matchPattern(reshapeOp.src(), m_Constant(&attr)))
|
||||
if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
|
||||
return failure();
|
||||
if (!attr || !attr.isSplat())
|
||||
return failure();
|
||||
|
@ -910,7 +913,7 @@ struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
|
|||
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto fromElements =
|
||||
reshapeOp.src().template getDefiningOp<FromElementsOp>();
|
||||
reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
|
||||
if (!fromElements)
|
||||
return failure();
|
||||
|
||||
|
@ -920,7 +923,7 @@ struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
|
|||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
|
||||
fromElements.elements());
|
||||
fromElements.getElements());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -1208,7 +1211,7 @@ public:
|
|||
}))
|
||||
return failure();
|
||||
|
||||
auto castOp = sliceOp.source().getDefiningOp<tensor::CastOp>();
|
||||
auto castOp = sliceOp.getSource().getDefiningOp<tensor::CastOp>();
|
||||
if (!castOp)
|
||||
return failure();
|
||||
|
||||
|
@ -1221,9 +1224,9 @@ public:
|
|||
sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
|
||||
sliceOp.getMixedStrides());
|
||||
Value newSlice = rewriter.create<ExtractSliceOp>(
|
||||
sliceOp.getLoc(), resultType, castOp.source(), sliceOp.offsets(),
|
||||
sliceOp.sizes(), sliceOp.strides(), sliceOp.static_offsets(),
|
||||
sliceOp.static_sizes(), sliceOp.static_strides());
|
||||
sliceOp.getLoc(), resultType, castOp.getSource(), sliceOp.getOffsets(),
|
||||
sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
|
||||
sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(sliceOp, sliceOp.getType(),
|
||||
newSlice);
|
||||
return success();
|
||||
|
@ -1277,7 +1280,7 @@ public:
|
|||
LogicalResult matchAndRewrite(ExtractSliceOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
DenseElementsAttr attr;
|
||||
if (!matchPattern(op.source(), m_Constant(&attr)))
|
||||
if (!matchPattern(op.getSource(), m_Constant(&attr)))
|
||||
return failure();
|
||||
|
||||
// A constant splat is handled by fold().
|
||||
|
@ -1285,8 +1288,8 @@ public:
|
|||
return failure();
|
||||
|
||||
// Dynamic result shape is not supported.
|
||||
auto sourceType = op.source().getType().cast<ShapedType>();
|
||||
auto resultType = op.result().getType().cast<ShapedType>();
|
||||
auto sourceType = op.getSource().getType().cast<ShapedType>();
|
||||
auto resultType = op.getResult().getType().cast<ShapedType>();
|
||||
if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
|
@ -1299,13 +1302,13 @@ public:
|
|||
return failure();
|
||||
|
||||
// Check if there are any dynamic parts, which are not supported.
|
||||
auto offsets = extractFromI64ArrayAttr(op.static_offsets());
|
||||
auto offsets = extractFromI64ArrayAttr(op.getStaticOffsets());
|
||||
if (llvm::is_contained(offsets, ShapedType::kDynamicStrideOrOffset))
|
||||
return failure();
|
||||
auto sizes = extractFromI64ArrayAttr(op.static_sizes());
|
||||
auto sizes = extractFromI64ArrayAttr(op.getStaticSizes());
|
||||
if (llvm::is_contained(sizes, ShapedType::kDynamicSize))
|
||||
return failure();
|
||||
auto strides = extractFromI64ArrayAttr(op.static_strides());
|
||||
auto strides = extractFromI64ArrayAttr(op.getStaticStrides());
|
||||
if (llvm::is_contained(strides, ShapedType::kDynamicStrideOrOffset))
|
||||
return failure();
|
||||
|
||||
|
@ -1414,25 +1417,25 @@ foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
|
|||
// TODO: This only checks the immediate producer; extend to go up the
|
||||
// insert/extract chain if the slices are disjoint.
|
||||
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
|
||||
auto insertOp = extractOp.source().getDefiningOp<InsertSliceOp>();
|
||||
auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
|
||||
|
||||
auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
|
||||
if (insertOp && insertOp.source().getType() == extractOp.getType() &&
|
||||
if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
|
||||
insertOp.isSameAs(extractOp, isSame))
|
||||
return insertOp.source();
|
||||
return insertOp.getSource();
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (auto splat = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
|
||||
auto resultType = result().getType().cast<ShapedType>();
|
||||
auto resultType = getResult().getType().cast<ShapedType>();
|
||||
if (resultType.hasStaticShape())
|
||||
return splat.resizeSplat(resultType);
|
||||
}
|
||||
if (getSourceType() == getType() &&
|
||||
succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
|
||||
return this->source();
|
||||
return this->getSource();
|
||||
if (Value slice = foldExtractAfterInsertSlice(*this))
|
||||
return slice;
|
||||
|
||||
|
@ -1518,8 +1521,8 @@ verifyInsertSliceOp(ShapedType srcType, ShapedType dstType,
|
|||
LogicalResult InsertSliceOp::verify() {
|
||||
ShapedType expectedType;
|
||||
auto result =
|
||||
verifyInsertSliceOp(getSourceType(), getType(), static_offsets(),
|
||||
static_sizes(), static_strides(), &expectedType);
|
||||
verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
|
||||
getStaticSizes(), getStaticStrides(), &expectedType);
|
||||
return produceSliceErrorMsg(result, *this, expectedType);
|
||||
}
|
||||
|
||||
|
@ -1539,15 +1542,15 @@ LogicalResult InsertSliceOp::verify() {
|
|||
/// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
|
||||
/// ```
|
||||
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
|
||||
auto prevInsertOp = insertOp.dest().getDefiningOp<InsertSliceOp>();
|
||||
auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
|
||||
|
||||
auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
|
||||
if (!prevInsertOp ||
|
||||
prevInsertOp.source().getType() != insertOp.source().getType() ||
|
||||
prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
|
||||
!prevInsertOp.isSameAs(insertOp, isSame))
|
||||
return failure();
|
||||
|
||||
insertOp.destMutable().assign(prevInsertOp.dest());
|
||||
insertOp.getDestMutable().assign(prevInsertOp.getDest());
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -1555,7 +1558,7 @@ OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
|
|||
if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
|
||||
getSourceType() == getType() &&
|
||||
succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
|
||||
return this->source();
|
||||
return this->getSource();
|
||||
if (succeeded(foldInsertAfterInsertSlice(*this)))
|
||||
return getResult();
|
||||
return OpFoldResult();
|
||||
|
@ -1566,7 +1569,7 @@ LogicalResult InsertSliceOp::reifyResultShapes(
|
|||
reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
|
||||
for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
|
||||
reifiedReturnShapes[0][dim] =
|
||||
builder.createOrFold<tensor::DimOp>(getLoc(), dest(), dim);
|
||||
builder.createOrFold<tensor::DimOp>(getLoc(), getDest(), dim);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
@ -1600,13 +1603,13 @@ public:
|
|||
auto sourceType = ExtractSliceOp::inferRankReducedResultType(
|
||||
insertSliceOp.getSourceType().getRank(), insertSliceOp.getType(),
|
||||
mixedOffsets, mixedSizes, mixedStrides);
|
||||
Value toInsert = insertSliceOp.source();
|
||||
Value toInsert = insertSliceOp.getSource();
|
||||
if (sourceType != insertSliceOp.getSourceType())
|
||||
toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
|
||||
sourceType, toInsert);
|
||||
rewriter.replaceOpWithNewOp<InsertSliceOp>(
|
||||
insertSliceOp, toInsert, insertSliceOp.dest(), mixedOffsets, mixedSizes,
|
||||
mixedStrides);
|
||||
insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
|
||||
mixedSizes, mixedStrides);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -1643,22 +1646,23 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
|
|||
auto castOp = v.getDefiningOp<tensor::CastOp>();
|
||||
if (!castOp || !canFoldIntoConsumerOp(castOp))
|
||||
return llvm::None;
|
||||
return castOp.source();
|
||||
return castOp.getSource();
|
||||
};
|
||||
Optional<Value> sourceCastSource =
|
||||
getSourceOfCastOp(insertSliceOp.source());
|
||||
Optional<Value> destCastSource = getSourceOfCastOp(insertSliceOp.dest());
|
||||
getSourceOfCastOp(insertSliceOp.getSource());
|
||||
Optional<Value> destCastSource = getSourceOfCastOp(insertSliceOp.getDest());
|
||||
if (!sourceCastSource && !destCastSource)
|
||||
return failure();
|
||||
|
||||
auto src = (sourceCastSource ? *sourceCastSource : insertSliceOp.source());
|
||||
auto dst = (destCastSource ? *destCastSource : insertSliceOp.dest());
|
||||
auto src =
|
||||
(sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
|
||||
auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
|
||||
|
||||
auto srcType = src.getType().cast<ShapedType>();
|
||||
auto dstType = dst.getType().cast<ShapedType>();
|
||||
if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.static_offsets(),
|
||||
insertSliceOp.static_sizes(),
|
||||
insertSliceOp.static_strides()) !=
|
||||
if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
|
||||
insertSliceOp.getStaticSizes(),
|
||||
insertSliceOp.getStaticStrides()) !=
|
||||
SliceVerificationResult::Success)
|
||||
return failure();
|
||||
|
||||
|
@ -1724,9 +1728,9 @@ struct InsertSliceOpSourceCastInserter final
|
|||
// 3) Cast-compatible with srcType.
|
||||
// Insert the cast.
|
||||
Value cast = rewriter.create<tensor::CastOp>(
|
||||
insertSliceOp.getLoc(), newSrcType, insertSliceOp.source());
|
||||
insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
|
||||
rewriter.replaceOpWithNewOp<InsertSliceOp>(
|
||||
insertSliceOp, cast, insertSliceOp.dest(),
|
||||
insertSliceOp, cast, insertSliceOp.getDest(),
|
||||
insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
|
||||
insertSliceOp.getMixedStrides());
|
||||
return success();
|
||||
|
@ -1781,11 +1785,11 @@ ParseResult parseInferType(OpAsmParser &parser,
|
|||
}
|
||||
|
||||
LogicalResult PadOp::verify() {
|
||||
auto sourceType = source().getType().cast<RankedTensorType>();
|
||||
auto resultType = result().getType().cast<RankedTensorType>();
|
||||
auto expectedType =
|
||||
PadOp::inferResultType(sourceType, extractFromI64ArrayAttr(static_low()),
|
||||
extractFromI64ArrayAttr(static_high()));
|
||||
auto sourceType = getSource().getType().cast<RankedTensorType>();
|
||||
auto resultType = getResult().getType().cast<RankedTensorType>();
|
||||
auto expectedType = PadOp::inferResultType(
|
||||
sourceType, extractFromI64ArrayAttr(getStaticLow()),
|
||||
extractFromI64ArrayAttr(getStaticHigh()));
|
||||
for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
|
||||
if (resultType.getDimSize(i) == expectedType.getDimSize(i))
|
||||
continue;
|
||||
|
@ -1801,7 +1805,7 @@ LogicalResult PadOp::verify() {
|
|||
|
||||
LogicalResult PadOp::verifyRegions() {
|
||||
auto ®ion = getRegion();
|
||||
unsigned rank = result().getType().cast<RankedTensorType>().getRank();
|
||||
unsigned rank = getResult().getType().cast<RankedTensorType>().getRank();
|
||||
Block &block = region.front();
|
||||
if (block.getNumArguments() != rank)
|
||||
return emitError("expected the block to have ") << rank << " arguments";
|
||||
|
@ -1815,7 +1819,7 @@ LogicalResult PadOp::verifyRegions() {
|
|||
|
||||
// Ensure that the region yields an element of the right type.
|
||||
auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
|
||||
if (yieldOp.value().getType() !=
|
||||
if (yieldOp.getValue().getType() !=
|
||||
getType().cast<ShapedType>().getElementType())
|
||||
return emitOpError("expected yield type to match shape element type");
|
||||
|
||||
|
@ -1919,10 +1923,11 @@ struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
|
|||
PatternRewriter &rewriter) const override {
|
||||
if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
|
||||
return failure();
|
||||
if (padTensorOp.nofold())
|
||||
if (padTensorOp.getNofold())
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(
|
||||
padTensorOp, padTensorOp.result().getType(), padTensorOp.source());
|
||||
padTensorOp, padTensorOp.getResult().getType(),
|
||||
padTensorOp.getSource());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -1933,25 +1938,26 @@ struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
|
|||
|
||||
LogicalResult matchAndRewrite(PadOp padTensorOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto castOp = padTensorOp.source().getDefiningOp<tensor::CastOp>();
|
||||
auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
|
||||
if (!tensor::canFoldIntoConsumerOp(castOp))
|
||||
return failure();
|
||||
|
||||
auto newResultType = PadOp::inferResultType(
|
||||
castOp.source().getType().cast<RankedTensorType>(),
|
||||
extractFromI64ArrayAttr(padTensorOp.static_low()),
|
||||
extractFromI64ArrayAttr(padTensorOp.static_high()),
|
||||
castOp.getSource().getType().cast<RankedTensorType>(),
|
||||
extractFromI64ArrayAttr(padTensorOp.getStaticLow()),
|
||||
extractFromI64ArrayAttr(padTensorOp.getStaticHigh()),
|
||||
padTensorOp.getResultType().getShape());
|
||||
|
||||
if (newResultType == padTensorOp.getResultType()) {
|
||||
rewriter.updateRootInPlace(padTensorOp, [&]() {
|
||||
padTensorOp.sourceMutable().assign(castOp.source());
|
||||
padTensorOp.getSourceMutable().assign(castOp.getSource());
|
||||
});
|
||||
} else {
|
||||
auto newOp = rewriter.create<PadOp>(
|
||||
padTensorOp->getLoc(), newResultType, padTensorOp.source(),
|
||||
padTensorOp.low(), padTensorOp.high(), padTensorOp.static_low(),
|
||||
padTensorOp.static_high(), padTensorOp.nofold());
|
||||
padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
|
||||
padTensorOp.getLow(), padTensorOp.getHigh(),
|
||||
padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
|
||||
padTensorOp.getNofold());
|
||||
BlockAndValueMapping mapper;
|
||||
padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
|
||||
|
||||
|
@ -1969,25 +1975,25 @@ struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
|
|||
|
||||
LogicalResult matchAndRewrite(PadOp padTensorOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (!padTensorOp.result().hasOneUse())
|
||||
if (!padTensorOp.getResult().hasOneUse())
|
||||
return failure();
|
||||
auto tensorCastOp =
|
||||
dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
|
||||
if (!tensorCastOp)
|
||||
return failure();
|
||||
if (!tensor::preservesStaticInformation(padTensorOp.result().getType(),
|
||||
tensorCastOp.dest().getType()))
|
||||
if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
|
||||
tensorCastOp.getDest().getType()))
|
||||
return failure();
|
||||
|
||||
auto replacementOp = rewriter.create<PadOp>(
|
||||
padTensorOp.getLoc(), tensorCastOp.dest().getType(),
|
||||
padTensorOp.source(), padTensorOp.low(), padTensorOp.high(),
|
||||
padTensorOp.static_low(), padTensorOp.static_high(),
|
||||
padTensorOp.nofold());
|
||||
replacementOp.region().takeBody(padTensorOp.region());
|
||||
padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
|
||||
padTensorOp.getSource(), padTensorOp.getLow(), padTensorOp.getHigh(),
|
||||
padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
|
||||
padTensorOp.getNofold());
|
||||
replacementOp.getRegion().takeBody(padTensorOp.getRegion());
|
||||
|
||||
rewriter.replaceOp(padTensorOp, replacementOp.result());
|
||||
rewriter.replaceOp(tensorCastOp, replacementOp.result());
|
||||
rewriter.replaceOp(padTensorOp, replacementOp.getResult());
|
||||
rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -2031,13 +2037,13 @@ struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
|
|||
|
||||
LogicalResult matchAndRewrite(PadOp padOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto innerSliceOp = padOp.source().getDefiningOp<ExtractSliceOp>();
|
||||
auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
|
||||
if (!innerSliceOp)
|
||||
return failure();
|
||||
auto outerPadOp = innerSliceOp.source().getDefiningOp<PadOp>();
|
||||
if (!outerPadOp || outerPadOp.nofold())
|
||||
auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
|
||||
if (!outerPadOp || outerPadOp.getNofold())
|
||||
return failure();
|
||||
auto outerSliceOp = outerPadOp.source().getDefiningOp<ExtractSliceOp>();
|
||||
auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
|
||||
if (!outerSliceOp)
|
||||
return failure();
|
||||
|
||||
|
@ -2136,11 +2142,11 @@ struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
|
|||
// Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs the
|
||||
// two paddings in one step.
|
||||
auto newSliceOp = rewriter.create<ExtractSliceOp>(
|
||||
padOp.getLoc(), outerSliceOp.source(), newOffsets, newSizes,
|
||||
padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
|
||||
innerSliceOp.getMixedStrides());
|
||||
auto newPadOp = rewriter.create<PadOp>(
|
||||
padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
|
||||
padOp.getMixedLowPad(), newHighPad, padOp.nofold());
|
||||
padOp.getMixedLowPad(), newHighPad, padOp.getNofold());
|
||||
rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
|
||||
newPadOp.getRegion().begin());
|
||||
rewriter.replaceOp(padOp, newPadOp.getResult());
|
||||
|
@ -2169,7 +2175,7 @@ Value PadOp::getConstantPaddingValue() {
|
|||
auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
|
||||
if (!yieldOp)
|
||||
return {};
|
||||
Value padValue = yieldOp.value();
|
||||
Value padValue = yieldOp.getValue();
|
||||
// Check if yield value is a constant.
|
||||
if (matchPattern(padValue, m_Constant()))
|
||||
return padValue;
|
||||
|
@ -2182,8 +2188,8 @@ Value PadOp::getConstantPaddingValue() {
|
|||
|
||||
OpFoldResult PadOp::fold(ArrayRef<Attribute>) {
|
||||
if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
|
||||
!nofold())
|
||||
return source();
|
||||
!getNofold())
|
||||
return getSource();
|
||||
return {};
|
||||
}
|
||||
|
||||
|
|
|
@ -75,7 +75,7 @@ struct SplitPadding final : public OpRewritePattern<tensor::PadOp> {
|
|||
// Build the scf.if op itself. For the "then" branch, we can elide the
|
||||
// padding. For the "else" branch, we retain the clone op.
|
||||
auto thenBuilder = [&padOp](OpBuilder &builder, Location loc) {
|
||||
builder.create<scf::YieldOp>(loc, padOp.source());
|
||||
builder.create<scf::YieldOp>(loc, padOp.getSource());
|
||||
};
|
||||
auto elseBuilder = [&padOp](OpBuilder &builder, Location loc) {
|
||||
Operation *newOp = builder.clone(*padOp);
|
||||
|
|
Loading…
Reference in New Issue