forked from OSchip/llvm-project
[mlir] Flip more uses to prefixed accessor form (NFC).
Try to keep the final flip small. Need to flip MemRef as there are many templated cases with it and Tensor.
This commit is contained in:
parent
ca05cc2064
commit
2d70eff802
|
@ -97,7 +97,7 @@ def TileZeroOp : AMX_Op<"tile_zero", [NoSideEffect]> {
|
||||||
VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
|
VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
VectorType getVectorType() {
|
VectorType getVectorType() {
|
||||||
return res().getType().cast<VectorType>();
|
return getRes().getType().cast<VectorType>();
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
let assemblyFormat = "attr-dict `:` type($res)";
|
let assemblyFormat = "attr-dict `:` type($res)";
|
||||||
|
@ -128,10 +128,10 @@ def TileLoadOp : AMX_Op<"tile_load", [NoSideEffect]> {
|
||||||
VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
|
VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
MemRefType getMemRefType() {
|
MemRefType getMemRefType() {
|
||||||
return base().getType().cast<MemRefType>();
|
return getBase().getType().cast<MemRefType>();
|
||||||
}
|
}
|
||||||
VectorType getVectorType() {
|
VectorType getVectorType() {
|
||||||
return res().getType().cast<VectorType>();
|
return getRes().getType().cast<VectorType>();
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
|
let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
|
||||||
|
@ -158,10 +158,10 @@ def TileStoreOp : AMX_Op<"tile_store"> {
|
||||||
VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$val);
|
VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$val);
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
MemRefType getMemRefType() {
|
MemRefType getMemRefType() {
|
||||||
return base().getType().cast<MemRefType>();
|
return getBase().getType().cast<MemRefType>();
|
||||||
}
|
}
|
||||||
VectorType getVectorType() {
|
VectorType getVectorType() {
|
||||||
return val().getType().cast<VectorType>();
|
return getVal().getType().cast<VectorType>();
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
|
let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
|
||||||
|
@ -194,13 +194,13 @@ def TileMulFOp : AMX_Op<"tile_mulf", [NoSideEffect, AllTypesMatch<["acc", "res"]
|
||||||
let results = (outs VectorOfRankAndType<[2], [F32, BF16]>:$res);
|
let results = (outs VectorOfRankAndType<[2], [F32, BF16]>:$res);
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
VectorType getLhsVectorType() {
|
VectorType getLhsVectorType() {
|
||||||
return lhs().getType().cast<VectorType>();
|
return getLhs().getType().cast<VectorType>();
|
||||||
}
|
}
|
||||||
VectorType getRhsVectorType() {
|
VectorType getRhsVectorType() {
|
||||||
return rhs().getType().cast<VectorType>();
|
return getRhs().getType().cast<VectorType>();
|
||||||
}
|
}
|
||||||
VectorType getVectorType() {
|
VectorType getVectorType() {
|
||||||
return res().getType().cast<VectorType>();
|
return getRes().getType().cast<VectorType>();
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
|
let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
|
||||||
|
@ -235,13 +235,13 @@ def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"]
|
||||||
let results = (outs VectorOfRankAndType<[2], [I32, I8]>:$res);
|
let results = (outs VectorOfRankAndType<[2], [I32, I8]>:$res);
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
VectorType getLhsVectorType() {
|
VectorType getLhsVectorType() {
|
||||||
return lhs().getType().cast<VectorType>();
|
return getLhs().getType().cast<VectorType>();
|
||||||
}
|
}
|
||||||
VectorType getRhsVectorType() {
|
VectorType getRhsVectorType() {
|
||||||
return rhs().getType().cast<VectorType>();
|
return getRhs().getType().cast<VectorType>();
|
||||||
}
|
}
|
||||||
VectorType getVectorType() {
|
VectorType getVectorType() {
|
||||||
return res().getType().cast<VectorType>();
|
return getRes().getType().cast<VectorType>();
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
|
let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
|
||||||
|
|
|
@ -27,7 +27,7 @@ class Tensor_OpWithOffsetSizesAndStrides<string mnemonic,
|
||||||
: Tensor_Op<mnemonic, traits> {
|
: Tensor_Op<mnemonic, traits> {
|
||||||
code extraBaseClassDeclaration = [{
|
code extraBaseClassDeclaration = [{
|
||||||
/// Returns the dynamic sizes for this subview operation if specified.
|
/// Returns the dynamic sizes for this subview operation if specified.
|
||||||
::mlir::Operation::operand_range getDynamicSizes() { return sizes(); }
|
::mlir::Operation::operand_range getDynamicSizes() { return getSizes(); }
|
||||||
|
|
||||||
/// Return the list of Range (i.e. offset, size, stride). Each
|
/// Return the list of Range (i.e. offset, size, stride). Each
|
||||||
/// Range entry contains either the dynamic value or a ConstantIndexOp
|
/// Range entry contains either the dynamic value or a ConstantIndexOp
|
||||||
|
@ -266,7 +266,7 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
|
||||||
let extraClassDeclaration = extraBaseClassDeclaration # [{
|
let extraClassDeclaration = extraBaseClassDeclaration # [{
|
||||||
/// Returns the type of the base tensor operand.
|
/// Returns the type of the base tensor operand.
|
||||||
RankedTensorType getSourceType() {
|
RankedTensorType getSourceType() {
|
||||||
return source().getType().cast<RankedTensorType>();
|
return getSource().getType().cast<RankedTensorType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The result of an extract_slice is always a tensor.
|
/// The result of an extract_slice is always a tensor.
|
||||||
|
@ -559,7 +559,7 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
|
||||||
let extraClassDeclaration = extraBaseClassDeclaration # [{
|
let extraClassDeclaration = extraBaseClassDeclaration # [{
|
||||||
/// Returns the type of the base tensor operand.
|
/// Returns the type of the base tensor operand.
|
||||||
RankedTensorType getSourceType() {
|
RankedTensorType getSourceType() {
|
||||||
return source().getType().cast<RankedTensorType>();
|
return getSource().getType().cast<RankedTensorType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The result of a insert_slice is always a tensor.
|
/// The result of a insert_slice is always a tensor.
|
||||||
|
@ -685,7 +685,7 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
|
||||||
SmallVector<ReassociationExprs, 4> getReassociationExprs();
|
SmallVector<ReassociationExprs, 4> getReassociationExprs();
|
||||||
SmallVector<ReassociationIndices, 4> getReassociationIndices() {
|
SmallVector<ReassociationIndices, 4> getReassociationIndices() {
|
||||||
SmallVector<ReassociationIndices, 4> reassociationIndices;
|
SmallVector<ReassociationIndices, 4> reassociationIndices;
|
||||||
for (auto attr : reassociation())
|
for (auto attr : getReassociation())
|
||||||
reassociationIndices.push_back(llvm::to_vector<2>(
|
reassociationIndices.push_back(llvm::to_vector<2>(
|
||||||
llvm::map_range(attr.cast<ArrayAttr>(), [&](Attribute indexAttr) {
|
llvm::map_range(attr.cast<ArrayAttr>(), [&](Attribute indexAttr) {
|
||||||
return indexAttr.cast<IntegerAttr>().getInt();
|
return indexAttr.cast<IntegerAttr>().getInt();
|
||||||
|
@ -693,10 +693,10 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
|
||||||
return reassociationIndices;
|
return reassociationIndices;
|
||||||
};
|
};
|
||||||
RankedTensorType getSrcType() {
|
RankedTensorType getSrcType() {
|
||||||
return src().getType().cast<RankedTensorType>();
|
return getSrc().getType().cast<RankedTensorType>();
|
||||||
}
|
}
|
||||||
RankedTensorType getResultType() {
|
RankedTensorType getResultType() {
|
||||||
return result().getType().cast<RankedTensorType>();
|
return getResult().getType().cast<RankedTensorType>();
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
|
|
||||||
|
@ -930,7 +930,7 @@ def Tensor_PadOp : Tensor_Op<"pad", [AttrSizedOperandSegments, NoSideEffect,
|
||||||
}
|
}
|
||||||
|
|
||||||
RankedTensorType getSourceType() {
|
RankedTensorType getSourceType() {
|
||||||
return source().getType().cast<RankedTensorType>();
|
return getSource().getType().cast<RankedTensorType>();
|
||||||
}
|
}
|
||||||
RankedTensorType getResultType() {
|
RankedTensorType getResultType() {
|
||||||
return getResult().getType().cast<RankedTensorType>();
|
return getResult().getType().cast<RankedTensorType>();
|
||||||
|
@ -965,10 +965,10 @@ def Tensor_PadOp : Tensor_Op<"pad", [AttrSizedOperandSegments, NoSideEffect,
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
SmallVector<OpFoldResult> getMixedLowPad() {
|
SmallVector<OpFoldResult> getMixedLowPad() {
|
||||||
return getMixedPadImpl(static_low(), low());
|
return getMixedPadImpl(getStaticLow(), getLow());
|
||||||
}
|
}
|
||||||
SmallVector<OpFoldResult> getMixedHighPad() {
|
SmallVector<OpFoldResult> getMixedHighPad() {
|
||||||
return getMixedPadImpl(static_high(), high());
|
return getMixedPadImpl(getStaticHigh(), getHigh());
|
||||||
}
|
}
|
||||||
// Return true if low padding is guaranteed to be 0.
|
// Return true if low padding is guaranteed to be 0.
|
||||||
bool hasZeroLowPad() {
|
bool hasZeroLowPad() {
|
||||||
|
|
|
@ -100,7 +100,7 @@ public:
|
||||||
LogicalResult matchAndRewrite(AffineMinOp op,
|
LogicalResult matchAndRewrite(AffineMinOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Value reduced =
|
Value reduced =
|
||||||
lowerAffineMapMin(rewriter, op.getLoc(), op.map(), op.operands());
|
lowerAffineMapMin(rewriter, op.getLoc(), op.getMap(), op.operands());
|
||||||
if (!reduced)
|
if (!reduced)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
@ -116,7 +116,7 @@ public:
|
||||||
LogicalResult matchAndRewrite(AffineMaxOp op,
|
LogicalResult matchAndRewrite(AffineMaxOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Value reduced =
|
Value reduced =
|
||||||
lowerAffineMapMax(rewriter, op.getLoc(), op.map(), op.operands());
|
lowerAffineMapMax(rewriter, op.getLoc(), op.getMap(), op.operands());
|
||||||
if (!reduced)
|
if (!reduced)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
@ -156,7 +156,7 @@ public:
|
||||||
auto scfForOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound,
|
auto scfForOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound,
|
||||||
step, op.getIterOperands());
|
step, op.getIterOperands());
|
||||||
rewriter.eraseBlock(scfForOp.getBody());
|
rewriter.eraseBlock(scfForOp.getBody());
|
||||||
rewriter.inlineRegionBefore(op.region(), scfForOp.getRegion(),
|
rewriter.inlineRegionBefore(op.getRegion(), scfForOp.getRegion(),
|
||||||
scfForOp.getRegion().end());
|
scfForOp.getRegion().end());
|
||||||
rewriter.replaceOp(op, scfForOp.getResults());
|
rewriter.replaceOp(op, scfForOp.getResults());
|
||||||
return success();
|
return success();
|
||||||
|
@ -193,20 +193,20 @@ public:
|
||||||
return rewriter.notifyMatchFailure(op, "couldn't convert upper bounds");
|
return rewriter.notifyMatchFailure(op, "couldn't convert upper bounds");
|
||||||
upperBoundTuple.push_back(upper);
|
upperBoundTuple.push_back(upper);
|
||||||
}
|
}
|
||||||
steps.reserve(op.steps().size());
|
steps.reserve(op.getSteps().size());
|
||||||
for (int64_t step : op.steps())
|
for (int64_t step : op.getSteps())
|
||||||
steps.push_back(rewriter.create<arith::ConstantIndexOp>(loc, step));
|
steps.push_back(rewriter.create<arith::ConstantIndexOp>(loc, step));
|
||||||
|
|
||||||
// Get the terminator op.
|
// Get the terminator op.
|
||||||
Operation *affineParOpTerminator = op.getBody()->getTerminator();
|
Operation *affineParOpTerminator = op.getBody()->getTerminator();
|
||||||
scf::ParallelOp parOp;
|
scf::ParallelOp parOp;
|
||||||
if (op.results().empty()) {
|
if (op.getResults().empty()) {
|
||||||
// Case with no reduction operations/return values.
|
// Case with no reduction operations/return values.
|
||||||
parOp = rewriter.create<scf::ParallelOp>(loc, lowerBoundTuple,
|
parOp = rewriter.create<scf::ParallelOp>(loc, lowerBoundTuple,
|
||||||
upperBoundTuple, steps,
|
upperBoundTuple, steps,
|
||||||
/*bodyBuilderFn=*/nullptr);
|
/*bodyBuilderFn=*/nullptr);
|
||||||
rewriter.eraseBlock(parOp.getBody());
|
rewriter.eraseBlock(parOp.getBody());
|
||||||
rewriter.inlineRegionBefore(op.region(), parOp.getRegion(),
|
rewriter.inlineRegionBefore(op.getRegion(), parOp.getRegion(),
|
||||||
parOp.getRegion().end());
|
parOp.getRegion().end());
|
||||||
rewriter.replaceOp(op, parOp.getResults());
|
rewriter.replaceOp(op, parOp.getResults());
|
||||||
return success();
|
return success();
|
||||||
|
@ -214,7 +214,7 @@ public:
|
||||||
// Case with affine.parallel with reduction operations/return values.
|
// Case with affine.parallel with reduction operations/return values.
|
||||||
// scf.parallel handles the reduction operation differently unlike
|
// scf.parallel handles the reduction operation differently unlike
|
||||||
// affine.parallel.
|
// affine.parallel.
|
||||||
ArrayRef<Attribute> reductions = op.reductions().getValue();
|
ArrayRef<Attribute> reductions = op.getReductions().getValue();
|
||||||
for (auto pair : llvm::zip(reductions, op.getResultTypes())) {
|
for (auto pair : llvm::zip(reductions, op.getResultTypes())) {
|
||||||
// For each of the reduction operations get the identity values for
|
// For each of the reduction operations get the identity values for
|
||||||
// initialization of the result values.
|
// initialization of the result values.
|
||||||
|
@ -234,7 +234,7 @@ public:
|
||||||
|
|
||||||
// Copy the body of the affine.parallel op.
|
// Copy the body of the affine.parallel op.
|
||||||
rewriter.eraseBlock(parOp.getBody());
|
rewriter.eraseBlock(parOp.getBody());
|
||||||
rewriter.inlineRegionBefore(op.region(), parOp.getRegion(),
|
rewriter.inlineRegionBefore(op.getRegion(), parOp.getRegion(),
|
||||||
parOp.getRegion().end());
|
parOp.getRegion().end());
|
||||||
assert(reductions.size() == affineParOpTerminator->getNumOperands() &&
|
assert(reductions.size() == affineParOpTerminator->getNumOperands() &&
|
||||||
"Unequal number of reductions and operands.");
|
"Unequal number of reductions and operands.");
|
||||||
|
@ -299,13 +299,14 @@ public:
|
||||||
: rewriter.create<arith::ConstantIntOp>(loc, /*value=*/1,
|
: rewriter.create<arith::ConstantIntOp>(loc, /*value=*/1,
|
||||||
/*width=*/1);
|
/*width=*/1);
|
||||||
|
|
||||||
bool hasElseRegion = !op.elseRegion().empty();
|
bool hasElseRegion = !op.getElseRegion().empty();
|
||||||
auto ifOp = rewriter.create<scf::IfOp>(loc, op.getResultTypes(), cond,
|
auto ifOp = rewriter.create<scf::IfOp>(loc, op.getResultTypes(), cond,
|
||||||
hasElseRegion);
|
hasElseRegion);
|
||||||
rewriter.inlineRegionBefore(op.thenRegion(), &ifOp.getThenRegion().back());
|
rewriter.inlineRegionBefore(op.getThenRegion(),
|
||||||
|
&ifOp.getThenRegion().back());
|
||||||
rewriter.eraseBlock(&ifOp.getThenRegion().back());
|
rewriter.eraseBlock(&ifOp.getThenRegion().back());
|
||||||
if (hasElseRegion) {
|
if (hasElseRegion) {
|
||||||
rewriter.inlineRegionBefore(op.elseRegion(),
|
rewriter.inlineRegionBefore(op.getElseRegion(),
|
||||||
&ifOp.getElseRegion().back());
|
&ifOp.getElseRegion().back());
|
||||||
rewriter.eraseBlock(&ifOp.getElseRegion().back());
|
rewriter.eraseBlock(&ifOp.getElseRegion().back());
|
||||||
}
|
}
|
||||||
|
@ -375,8 +376,8 @@ public:
|
||||||
|
|
||||||
// Build memref.prefetch memref[expandedMap.results].
|
// Build memref.prefetch memref[expandedMap.results].
|
||||||
rewriter.replaceOpWithNewOp<memref::PrefetchOp>(
|
rewriter.replaceOpWithNewOp<memref::PrefetchOp>(
|
||||||
op, op.memref(), *resultOperands, op.isWrite(), op.localityHint(),
|
op, op.getMemref(), *resultOperands, op.getIsWrite(),
|
||||||
op.isDataCache());
|
op.getLocalityHint(), op.getIsDataCache());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -61,7 +61,7 @@ static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
|
||||||
return tensor::createPadScalarOp(RankedTensorType::get(paddedShape, inputETy),
|
return tensor::createPadScalarOp(RankedTensorType::get(paddedShape, inputETy),
|
||||||
input, padValue, lowIndices, highIndices,
|
input, padValue, lowIndices, highIndices,
|
||||||
/*nofold=*/false, loc, rewriter)
|
/*nofold=*/false, loc, rewriter)
|
||||||
.result();
|
.getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
static mlir::Value reifyConstantDim(Attribute attr,
|
static mlir::Value reifyConstantDim(Attribute attr,
|
||||||
|
|
|
@ -889,12 +889,12 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
|
||||||
|
|
||||||
// Must be a tensor.cast op pair with matching types.
|
// Must be a tensor.cast op pair with matching types.
|
||||||
if (outgoingCastOp.getResult().getType() !=
|
if (outgoingCastOp.getResult().getType() !=
|
||||||
incomingCast.source().getType())
|
incomingCast.getSource().getType())
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
// Create a new ForOp with that iter operand replaced.
|
// Create a new ForOp with that iter operand replaced.
|
||||||
auto newForOp = replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
|
auto newForOp = replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
|
||||||
incomingCast.source());
|
incomingCast.getSource());
|
||||||
|
|
||||||
// Insert outgoing cast and use it to replace the corresponding result.
|
// Insert outgoing cast and use it to replace the corresponding result.
|
||||||
rewriter.setInsertionPointAfter(newForOp);
|
rewriter.setInsertionPointAfter(newForOp);
|
||||||
|
@ -902,7 +902,8 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
|
||||||
unsigned returnIdx =
|
unsigned returnIdx =
|
||||||
iterOpOperand.getOperandNumber() - op.getNumControlOperands();
|
iterOpOperand.getOperandNumber() - op.getNumControlOperands();
|
||||||
replacements[returnIdx] = rewriter.create<tensor::CastOp>(
|
replacements[returnIdx] = rewriter.create<tensor::CastOp>(
|
||||||
op.getLoc(), incomingCast.dest().getType(), replacements[returnIdx]);
|
op.getLoc(), incomingCast.getDest().getType(),
|
||||||
|
replacements[returnIdx]);
|
||||||
rewriter.replaceOp(op, replacements);
|
rewriter.replaceOp(op, replacements);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
|
@ -97,7 +97,7 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
|
||||||
// Can fold if the source of cast has at least as much static information as
|
// Can fold if the source of cast has at least as much static information as
|
||||||
// its results.
|
// its results.
|
||||||
return preservesStaticInformation(castOp.getType(),
|
return preservesStaticInformation(castOp.getType(),
|
||||||
castOp.source().getType());
|
castOp.getSource().getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Determines whether the tensor::CastOp casts to a more static version of the
|
/// Determines whether the tensor::CastOp casts to a more static version of the
|
||||||
|
@ -123,7 +123,7 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
|
||||||
bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
|
bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
|
||||||
if (!castOp)
|
if (!castOp)
|
||||||
return false;
|
return false;
|
||||||
return preservesStaticInformation(castOp.source().getType(),
|
return preservesStaticInformation(castOp.getSource().getType(),
|
||||||
castOp.getType());
|
castOp.getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -250,13 +250,15 @@ struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
|
||||||
tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
|
tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
|
||||||
|
|
||||||
if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
|
if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
|
||||||
tensorCast.getType().getShape() ==
|
tensorCast.getType().getShape() == tensorCast.getSource()
|
||||||
tensorCast.source().getType().cast<RankedTensorType>().getShape())
|
.getType()
|
||||||
|
.cast<RankedTensorType>()
|
||||||
|
.getShape())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
|
SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
|
||||||
auto dimMask = computeRankReductionMask(
|
auto dimMask = computeRankReductionMask(
|
||||||
extractFromI64ArrayAttr(extractOperand.static_sizes()),
|
extractFromI64ArrayAttr(extractOperand.getStaticSizes()),
|
||||||
extractOperand.getType().getShape());
|
extractOperand.getType().getShape());
|
||||||
size_t dimIndex = 0;
|
size_t dimIndex = 0;
|
||||||
for (size_t i = 0, e = sizes.size(); i < e; i++) {
|
for (size_t i = 0, e = sizes.size(); i < e; i++) {
|
||||||
|
@ -270,7 +272,7 @@ struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<ExtractSliceOp>(
|
rewriter.replaceOpWithNewOp<ExtractSliceOp>(
|
||||||
tensorCast, tensorCast.getType().cast<RankedTensorType>(),
|
tensorCast, tensorCast.getType().cast<RankedTensorType>(),
|
||||||
extractOperand.source(), extractOperand.getMixedOffsets(), sizes,
|
extractOperand.getSource(), extractOperand.getMixedOffsets(), sizes,
|
||||||
extractOperand.getMixedStrides());
|
extractOperand.getMixedStrides());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -295,7 +297,7 @@ void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
|
||||||
}
|
}
|
||||||
|
|
||||||
Optional<int64_t> DimOp::getConstantIndex() {
|
Optional<int64_t> DimOp::getConstantIndex() {
|
||||||
if (auto constantOp = index().getDefiningOp<arith::ConstantOp>())
|
if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
|
||||||
return constantOp.getValue().cast<IntegerAttr>().getInt();
|
return constantOp.getValue().cast<IntegerAttr>().getInt();
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
@ -307,7 +309,7 @@ LogicalResult DimOp::verify() {
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
// Check that constant index is not knowingly out of range.
|
// Check that constant index is not knowingly out of range.
|
||||||
auto type = source().getType();
|
auto type = getSource().getType();
|
||||||
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
|
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
|
||||||
if (*index >= tensorType.getRank())
|
if (*index >= tensorType.getRank())
|
||||||
return emitOpError("index is out of range");
|
return emitOpError("index is out of range");
|
||||||
|
@ -326,7 +328,7 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
|
||||||
return {};
|
return {};
|
||||||
|
|
||||||
// Folding for unranked types (UnrankedTensorType) is not supported.
|
// Folding for unranked types (UnrankedTensorType) is not supported.
|
||||||
auto tensorType = source().getType().dyn_cast<RankedTensorType>();
|
auto tensorType = getSource().getType().dyn_cast<RankedTensorType>();
|
||||||
if (!tensorType)
|
if (!tensorType)
|
||||||
return {};
|
return {};
|
||||||
|
|
||||||
|
@ -336,7 +338,7 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
|
||||||
return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
|
return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
|
||||||
}
|
}
|
||||||
|
|
||||||
Operation *definingOp = source().getDefiningOp();
|
Operation *definingOp = getSource().getDefiningOp();
|
||||||
|
|
||||||
// Fold dim to the operand of tensor.generate.
|
// Fold dim to the operand of tensor.generate.
|
||||||
if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
|
if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
|
||||||
|
@ -347,7 +349,7 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
|
||||||
assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
|
assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
|
||||||
|
|
||||||
// Find the operand of the fromElements that corresponds to this index.
|
// Find the operand of the fromElements that corresponds to this index.
|
||||||
auto dynExtents = fromElements.dynamicExtents().begin();
|
auto dynExtents = fromElements.getDynamicExtents().begin();
|
||||||
for (auto dim : resultType.getShape().take_front(index.getInt()))
|
for (auto dim : resultType.getShape().take_front(index.getInt()))
|
||||||
if (ShapedType::isDynamic(dim))
|
if (ShapedType::isDynamic(dim))
|
||||||
dynExtents++;
|
dynExtents++;
|
||||||
|
@ -381,11 +383,11 @@ struct DimOfCastOp : public OpRewritePattern<DimOp> {
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(DimOp dimOp,
|
LogicalResult matchAndRewrite(DimOp dimOp,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
auto castOp = dimOp.source().getDefiningOp<CastOp>();
|
auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
|
||||||
if (!castOp)
|
if (!castOp)
|
||||||
return failure();
|
return failure();
|
||||||
Value newSource = castOp.getOperand();
|
Value newSource = castOp.getOperand();
|
||||||
rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index());
|
rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -402,8 +404,8 @@ void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||||
|
|
||||||
LogicalResult ExtractOp::verify() {
|
LogicalResult ExtractOp::verify() {
|
||||||
// Verify the # indices match if we have a ranked type.
|
// Verify the # indices match if we have a ranked type.
|
||||||
if (auto tensorType = tensor().getType().dyn_cast<RankedTensorType>())
|
if (auto tensorType = getTensor().getType().dyn_cast<RankedTensorType>())
|
||||||
if (tensorType.getRank() != static_cast<int64_t>(indices().size()))
|
if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
|
||||||
return emitOpError("incorrect number of indices for extract_element");
|
return emitOpError("incorrect number of indices for extract_element");
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
|
@ -425,7 +427,7 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fold extract(from_elements(...)).
|
// Fold extract(from_elements(...)).
|
||||||
if (auto fromElementsOp = tensor().getDefiningOp<FromElementsOp>()) {
|
if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
|
||||||
auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
|
auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
|
||||||
auto rank = tensorType.getRank();
|
auto rank = tensorType.getRank();
|
||||||
assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
|
assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
|
||||||
|
@ -439,10 +441,10 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
|
||||||
}
|
}
|
||||||
// Prevent out of bounds accesses. This can happen in invalid code that will
|
// Prevent out of bounds accesses. This can happen in invalid code that will
|
||||||
// never execute.
|
// never execute.
|
||||||
if (static_cast<int>(fromElementsOp.elements().size()) <= flatIndex ||
|
if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
|
||||||
flatIndex < 0)
|
flatIndex < 0)
|
||||||
return {};
|
return {};
|
||||||
return fromElementsOp.elements()[flatIndex];
|
return fromElementsOp.getElements()[flatIndex];
|
||||||
}
|
}
|
||||||
|
|
||||||
// If this is an elements attribute, query the value at the given indices.
|
// If this is an elements attribute, query the value at the given indices.
|
||||||
|
@ -503,14 +505,14 @@ struct ExtractElementFromIndexCast
|
||||||
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
|
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
|
||||||
PatternRewriter &rewriter) const final {
|
PatternRewriter &rewriter) const final {
|
||||||
Location loc = extract.getLoc();
|
Location loc = extract.getLoc();
|
||||||
auto indexCast = extract.tensor().getDefiningOp<arith::IndexCastOp>();
|
auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
|
||||||
if (!indexCast)
|
if (!indexCast)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Type elementTy = getElementTypeOrSelf(indexCast.getIn());
|
Type elementTy = getElementTypeOrSelf(indexCast.getIn());
|
||||||
|
|
||||||
auto newExtract = rewriter.create<tensor::ExtractOp>(
|
auto newExtract = rewriter.create<tensor::ExtractOp>(
|
||||||
loc, elementTy, indexCast.getIn(), extract.indices());
|
loc, elementTy, indexCast.getIn(), extract.getIndices());
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
|
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
|
||||||
newExtract);
|
newExtract);
|
||||||
|
@ -532,8 +534,8 @@ void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||||
|
|
||||||
LogicalResult InsertOp::verify() {
|
LogicalResult InsertOp::verify() {
|
||||||
// Verify the # indices match if we have a ranked type.
|
// Verify the # indices match if we have a ranked type.
|
||||||
if (auto destType = dest().getType().dyn_cast<RankedTensorType>())
|
if (auto destType = getDest().getType().dyn_cast<RankedTensorType>())
|
||||||
if (destType.getRank() != static_cast<int64_t>(indices().size()))
|
if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
|
||||||
return emitOpError("incorrect number of indices");
|
return emitOpError("incorrect number of indices");
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -581,16 +583,16 @@ LogicalResult GenerateOp::verify() {
|
||||||
LogicalResult GenerateOp::verifyRegions() {
|
LogicalResult GenerateOp::verifyRegions() {
|
||||||
RankedTensorType resultTy = getType().cast<RankedTensorType>();
|
RankedTensorType resultTy = getType().cast<RankedTensorType>();
|
||||||
// Ensure that region arguments span the index space.
|
// Ensure that region arguments span the index space.
|
||||||
if (!llvm::all_of(body().getArgumentTypes(),
|
if (!llvm::all_of(getBody().getArgumentTypes(),
|
||||||
[](Type ty) { return ty.isIndex(); }))
|
[](Type ty) { return ty.isIndex(); }))
|
||||||
return emitError("all body arguments must be index");
|
return emitError("all body arguments must be index");
|
||||||
if (body().getNumArguments() != resultTy.getRank())
|
if (getBody().getNumArguments() != resultTy.getRank())
|
||||||
return emitError("must have one body argument per input dimension");
|
return emitError("must have one body argument per input dimension");
|
||||||
|
|
||||||
// Ensure that the region yields an element of the right type.
|
// Ensure that the region yields an element of the right type.
|
||||||
auto yieldOp = cast<YieldOp>(body().getBlocks().front().getTerminator());
|
auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
|
||||||
|
|
||||||
if (yieldOp.value().getType() != resultTy.getElementType())
|
if (yieldOp.getValue().getType() != resultTy.getElementType())
|
||||||
return emitOpError(
|
return emitOpError(
|
||||||
"body must be terminated with a `yield` operation of the tensor "
|
"body must be terminated with a `yield` operation of the tensor "
|
||||||
"element type");
|
"element type");
|
||||||
|
@ -634,7 +636,7 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
|
||||||
|
|
||||||
SmallVector<Value, 4> newOperands;
|
SmallVector<Value, 4> newOperands;
|
||||||
SmallVector<int64_t, 4> newShape;
|
SmallVector<int64_t, 4> newShape;
|
||||||
auto operandsIt = tensorFromElements.dynamicExtents().begin();
|
auto operandsIt = tensorFromElements.getDynamicExtents().begin();
|
||||||
|
|
||||||
for (int64_t dim : resultType.getShape()) {
|
for (int64_t dim : resultType.getShape()) {
|
||||||
if (!ShapedType::isDynamic(dim)) {
|
if (!ShapedType::isDynamic(dim)) {
|
||||||
|
@ -651,15 +653,15 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
|
||||||
operandsIt++;
|
operandsIt++;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (newOperands.size() == tensorFromElements.dynamicExtents().size())
|
if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto loc = tensorFromElements.getLoc();
|
auto loc = tensorFromElements.getLoc();
|
||||||
auto newOp = rewriter.create<GenerateOp>(
|
auto newOp = rewriter.create<GenerateOp>(
|
||||||
loc, RankedTensorType::get(newShape, resultType.getElementType()),
|
loc, RankedTensorType::get(newShape, resultType.getElementType()),
|
||||||
newOperands);
|
newOperands);
|
||||||
rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(),
|
rewriter.inlineRegionBefore(tensorFromElements.getBody(), newOp.getBody(),
|
||||||
newOp.body().begin());
|
newOp.getBody().begin());
|
||||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
|
||||||
newOp);
|
newOp);
|
||||||
return success();
|
return success();
|
||||||
|
@ -682,19 +684,19 @@ struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
|
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
|
||||||
PatternRewriter &rewriter) const final {
|
PatternRewriter &rewriter) const final {
|
||||||
auto tensorFromElements = extract.tensor().getDefiningOp<GenerateOp>();
|
auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
|
||||||
if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
|
if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
BlockAndValueMapping mapping;
|
BlockAndValueMapping mapping;
|
||||||
Block *body = &tensorFromElements.getBody().front();
|
Block *body = &tensorFromElements.getBody().front();
|
||||||
mapping.map(body->getArguments(), extract.indices());
|
mapping.map(body->getArguments(), extract.getIndices());
|
||||||
for (auto &op : body->without_terminator())
|
for (auto &op : body->without_terminator())
|
||||||
rewriter.clone(op, mapping);
|
rewriter.clone(op, mapping);
|
||||||
|
|
||||||
auto yield = cast<YieldOp>(body->getTerminator());
|
auto yield = cast<YieldOp>(body->getTerminator());
|
||||||
|
|
||||||
rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value()));
|
rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -712,12 +714,12 @@ struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
|
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
|
||||||
PatternRewriter &rewriter) const final {
|
PatternRewriter &rewriter) const final {
|
||||||
auto tensorCast = extract.tensor().getDefiningOp<tensor::CastOp>();
|
auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
|
||||||
if (!tensorCast)
|
if (!tensorCast)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(extract, tensorCast.source(),
|
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
|
||||||
extract.indices());
|
extract, tensorCast.getSource(), extract.getIndices());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -756,14 +758,15 @@ static int64_t getNumElements(ShapedType type) {
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult ReshapeOp::verify() {
|
LogicalResult ReshapeOp::verify() {
|
||||||
TensorType operandType = source().getType().cast<TensorType>();
|
TensorType operandType = getSource().getType().cast<TensorType>();
|
||||||
TensorType resultType = result().getType().cast<TensorType>();
|
TensorType resultType = getResult().getType().cast<TensorType>();
|
||||||
|
|
||||||
if (operandType.getElementType() != resultType.getElementType())
|
if (operandType.getElementType() != resultType.getElementType())
|
||||||
return emitOpError("element types of source and destination tensor "
|
return emitOpError("element types of source and destination tensor "
|
||||||
"types should be the same");
|
"types should be the same");
|
||||||
|
|
||||||
int64_t shapeSize = shape().getType().cast<RankedTensorType>().getDimSize(0);
|
int64_t shapeSize =
|
||||||
|
getShape().getType().cast<RankedTensorType>().getDimSize(0);
|
||||||
auto resultRankedType = resultType.dyn_cast<RankedTensorType>();
|
auto resultRankedType = resultType.dyn_cast<RankedTensorType>();
|
||||||
auto operandRankedType = operandType.dyn_cast<RankedTensorType>();
|
auto operandRankedType = operandType.dyn_cast<RankedTensorType>();
|
||||||
|
|
||||||
|
@ -891,7 +894,7 @@ struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
|
||||||
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
|
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
DenseElementsAttr attr;
|
DenseElementsAttr attr;
|
||||||
if (!matchPattern(reshapeOp.src(), m_Constant(&attr)))
|
if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
|
||||||
return failure();
|
return failure();
|
||||||
if (!attr || !attr.isSplat())
|
if (!attr || !attr.isSplat())
|
||||||
return failure();
|
return failure();
|
||||||
|
@ -910,7 +913,7 @@ struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
|
||||||
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
|
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
auto fromElements =
|
auto fromElements =
|
||||||
reshapeOp.src().template getDefiningOp<FromElementsOp>();
|
reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
|
||||||
if (!fromElements)
|
if (!fromElements)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
@ -920,7 +923,7 @@ struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
|
rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
|
||||||
fromElements.elements());
|
fromElements.getElements());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -1208,7 +1211,7 @@ public:
|
||||||
}))
|
}))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto castOp = sliceOp.source().getDefiningOp<tensor::CastOp>();
|
auto castOp = sliceOp.getSource().getDefiningOp<tensor::CastOp>();
|
||||||
if (!castOp)
|
if (!castOp)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
@ -1221,9 +1224,9 @@ public:
|
||||||
sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
|
sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
|
||||||
sliceOp.getMixedStrides());
|
sliceOp.getMixedStrides());
|
||||||
Value newSlice = rewriter.create<ExtractSliceOp>(
|
Value newSlice = rewriter.create<ExtractSliceOp>(
|
||||||
sliceOp.getLoc(), resultType, castOp.source(), sliceOp.offsets(),
|
sliceOp.getLoc(), resultType, castOp.getSource(), sliceOp.getOffsets(),
|
||||||
sliceOp.sizes(), sliceOp.strides(), sliceOp.static_offsets(),
|
sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
|
||||||
sliceOp.static_sizes(), sliceOp.static_strides());
|
sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
|
||||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(sliceOp, sliceOp.getType(),
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(sliceOp, sliceOp.getType(),
|
||||||
newSlice);
|
newSlice);
|
||||||
return success();
|
return success();
|
||||||
|
@ -1277,7 +1280,7 @@ public:
|
||||||
LogicalResult matchAndRewrite(ExtractSliceOp op,
|
LogicalResult matchAndRewrite(ExtractSliceOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
DenseElementsAttr attr;
|
DenseElementsAttr attr;
|
||||||
if (!matchPattern(op.source(), m_Constant(&attr)))
|
if (!matchPattern(op.getSource(), m_Constant(&attr)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// A constant splat is handled by fold().
|
// A constant splat is handled by fold().
|
||||||
|
@ -1285,8 +1288,8 @@ public:
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// Dynamic result shape is not supported.
|
// Dynamic result shape is not supported.
|
||||||
auto sourceType = op.source().getType().cast<ShapedType>();
|
auto sourceType = op.getSource().getType().cast<ShapedType>();
|
||||||
auto resultType = op.result().getType().cast<ShapedType>();
|
auto resultType = op.getResult().getType().cast<ShapedType>();
|
||||||
if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
|
if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
@ -1299,13 +1302,13 @@ public:
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// Check if there are any dynamic parts, which are not supported.
|
// Check if there are any dynamic parts, which are not supported.
|
||||||
auto offsets = extractFromI64ArrayAttr(op.static_offsets());
|
auto offsets = extractFromI64ArrayAttr(op.getStaticOffsets());
|
||||||
if (llvm::is_contained(offsets, ShapedType::kDynamicStrideOrOffset))
|
if (llvm::is_contained(offsets, ShapedType::kDynamicStrideOrOffset))
|
||||||
return failure();
|
return failure();
|
||||||
auto sizes = extractFromI64ArrayAttr(op.static_sizes());
|
auto sizes = extractFromI64ArrayAttr(op.getStaticSizes());
|
||||||
if (llvm::is_contained(sizes, ShapedType::kDynamicSize))
|
if (llvm::is_contained(sizes, ShapedType::kDynamicSize))
|
||||||
return failure();
|
return failure();
|
||||||
auto strides = extractFromI64ArrayAttr(op.static_strides());
|
auto strides = extractFromI64ArrayAttr(op.getStaticStrides());
|
||||||
if (llvm::is_contained(strides, ShapedType::kDynamicStrideOrOffset))
|
if (llvm::is_contained(strides, ShapedType::kDynamicStrideOrOffset))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
@ -1414,25 +1417,25 @@ foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
|
||||||
// TODO: This only checks the immediate producer; extend to go up the
|
// TODO: This only checks the immediate producer; extend to go up the
|
||||||
// insert/extract chain if the slices are disjoint.
|
// insert/extract chain if the slices are disjoint.
|
||||||
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
|
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
|
||||||
auto insertOp = extractOp.source().getDefiningOp<InsertSliceOp>();
|
auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
|
||||||
|
|
||||||
auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
|
auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
|
||||||
if (insertOp && insertOp.source().getType() == extractOp.getType() &&
|
if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
|
||||||
insertOp.isSameAs(extractOp, isSame))
|
insertOp.isSameAs(extractOp, isSame))
|
||||||
return insertOp.source();
|
return insertOp.getSource();
|
||||||
|
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute> operands) {
|
||||||
if (auto splat = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
|
if (auto splat = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
|
||||||
auto resultType = result().getType().cast<ShapedType>();
|
auto resultType = getResult().getType().cast<ShapedType>();
|
||||||
if (resultType.hasStaticShape())
|
if (resultType.hasStaticShape())
|
||||||
return splat.resizeSplat(resultType);
|
return splat.resizeSplat(resultType);
|
||||||
}
|
}
|
||||||
if (getSourceType() == getType() &&
|
if (getSourceType() == getType() &&
|
||||||
succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
|
succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
|
||||||
return this->source();
|
return this->getSource();
|
||||||
if (Value slice = foldExtractAfterInsertSlice(*this))
|
if (Value slice = foldExtractAfterInsertSlice(*this))
|
||||||
return slice;
|
return slice;
|
||||||
|
|
||||||
|
@ -1518,8 +1521,8 @@ verifyInsertSliceOp(ShapedType srcType, ShapedType dstType,
|
||||||
LogicalResult InsertSliceOp::verify() {
|
LogicalResult InsertSliceOp::verify() {
|
||||||
ShapedType expectedType;
|
ShapedType expectedType;
|
||||||
auto result =
|
auto result =
|
||||||
verifyInsertSliceOp(getSourceType(), getType(), static_offsets(),
|
verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
|
||||||
static_sizes(), static_strides(), &expectedType);
|
getStaticSizes(), getStaticStrides(), &expectedType);
|
||||||
return produceSliceErrorMsg(result, *this, expectedType);
|
return produceSliceErrorMsg(result, *this, expectedType);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1539,15 +1542,15 @@ LogicalResult InsertSliceOp::verify() {
|
||||||
/// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
|
/// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
|
||||||
/// ```
|
/// ```
|
||||||
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
|
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
|
||||||
auto prevInsertOp = insertOp.dest().getDefiningOp<InsertSliceOp>();
|
auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
|
||||||
|
|
||||||
auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
|
auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
|
||||||
if (!prevInsertOp ||
|
if (!prevInsertOp ||
|
||||||
prevInsertOp.source().getType() != insertOp.source().getType() ||
|
prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
|
||||||
!prevInsertOp.isSameAs(insertOp, isSame))
|
!prevInsertOp.isSameAs(insertOp, isSame))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
insertOp.destMutable().assign(prevInsertOp.dest());
|
insertOp.getDestMutable().assign(prevInsertOp.getDest());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1555,7 +1558,7 @@ OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
|
||||||
if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
|
if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
|
||||||
getSourceType() == getType() &&
|
getSourceType() == getType() &&
|
||||||
succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
|
succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
|
||||||
return this->source();
|
return this->getSource();
|
||||||
if (succeeded(foldInsertAfterInsertSlice(*this)))
|
if (succeeded(foldInsertAfterInsertSlice(*this)))
|
||||||
return getResult();
|
return getResult();
|
||||||
return OpFoldResult();
|
return OpFoldResult();
|
||||||
|
@ -1566,7 +1569,7 @@ LogicalResult InsertSliceOp::reifyResultShapes(
|
||||||
reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
|
reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
|
||||||
for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
|
for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
|
||||||
reifiedReturnShapes[0][dim] =
|
reifiedReturnShapes[0][dim] =
|
||||||
builder.createOrFold<tensor::DimOp>(getLoc(), dest(), dim);
|
builder.createOrFold<tensor::DimOp>(getLoc(), getDest(), dim);
|
||||||
}
|
}
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -1600,13 +1603,13 @@ public:
|
||||||
auto sourceType = ExtractSliceOp::inferRankReducedResultType(
|
auto sourceType = ExtractSliceOp::inferRankReducedResultType(
|
||||||
insertSliceOp.getSourceType().getRank(), insertSliceOp.getType(),
|
insertSliceOp.getSourceType().getRank(), insertSliceOp.getType(),
|
||||||
mixedOffsets, mixedSizes, mixedStrides);
|
mixedOffsets, mixedSizes, mixedStrides);
|
||||||
Value toInsert = insertSliceOp.source();
|
Value toInsert = insertSliceOp.getSource();
|
||||||
if (sourceType != insertSliceOp.getSourceType())
|
if (sourceType != insertSliceOp.getSourceType())
|
||||||
toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
|
toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
|
||||||
sourceType, toInsert);
|
sourceType, toInsert);
|
||||||
rewriter.replaceOpWithNewOp<InsertSliceOp>(
|
rewriter.replaceOpWithNewOp<InsertSliceOp>(
|
||||||
insertSliceOp, toInsert, insertSliceOp.dest(), mixedOffsets, mixedSizes,
|
insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
|
||||||
mixedStrides);
|
mixedSizes, mixedStrides);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -1643,22 +1646,23 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
|
||||||
auto castOp = v.getDefiningOp<tensor::CastOp>();
|
auto castOp = v.getDefiningOp<tensor::CastOp>();
|
||||||
if (!castOp || !canFoldIntoConsumerOp(castOp))
|
if (!castOp || !canFoldIntoConsumerOp(castOp))
|
||||||
return llvm::None;
|
return llvm::None;
|
||||||
return castOp.source();
|
return castOp.getSource();
|
||||||
};
|
};
|
||||||
Optional<Value> sourceCastSource =
|
Optional<Value> sourceCastSource =
|
||||||
getSourceOfCastOp(insertSliceOp.source());
|
getSourceOfCastOp(insertSliceOp.getSource());
|
||||||
Optional<Value> destCastSource = getSourceOfCastOp(insertSliceOp.dest());
|
Optional<Value> destCastSource = getSourceOfCastOp(insertSliceOp.getDest());
|
||||||
if (!sourceCastSource && !destCastSource)
|
if (!sourceCastSource && !destCastSource)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto src = (sourceCastSource ? *sourceCastSource : insertSliceOp.source());
|
auto src =
|
||||||
auto dst = (destCastSource ? *destCastSource : insertSliceOp.dest());
|
(sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
|
||||||
|
auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
|
||||||
|
|
||||||
auto srcType = src.getType().cast<ShapedType>();
|
auto srcType = src.getType().cast<ShapedType>();
|
||||||
auto dstType = dst.getType().cast<ShapedType>();
|
auto dstType = dst.getType().cast<ShapedType>();
|
||||||
if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.static_offsets(),
|
if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
|
||||||
insertSliceOp.static_sizes(),
|
insertSliceOp.getStaticSizes(),
|
||||||
insertSliceOp.static_strides()) !=
|
insertSliceOp.getStaticStrides()) !=
|
||||||
SliceVerificationResult::Success)
|
SliceVerificationResult::Success)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
@ -1724,9 +1728,9 @@ struct InsertSliceOpSourceCastInserter final
|
||||||
// 3) Cast-compatible with srcType.
|
// 3) Cast-compatible with srcType.
|
||||||
// Insert the cast.
|
// Insert the cast.
|
||||||
Value cast = rewriter.create<tensor::CastOp>(
|
Value cast = rewriter.create<tensor::CastOp>(
|
||||||
insertSliceOp.getLoc(), newSrcType, insertSliceOp.source());
|
insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
|
||||||
rewriter.replaceOpWithNewOp<InsertSliceOp>(
|
rewriter.replaceOpWithNewOp<InsertSliceOp>(
|
||||||
insertSliceOp, cast, insertSliceOp.dest(),
|
insertSliceOp, cast, insertSliceOp.getDest(),
|
||||||
insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
|
insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
|
||||||
insertSliceOp.getMixedStrides());
|
insertSliceOp.getMixedStrides());
|
||||||
return success();
|
return success();
|
||||||
|
@ -1781,11 +1785,11 @@ ParseResult parseInferType(OpAsmParser &parser,
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult PadOp::verify() {
|
LogicalResult PadOp::verify() {
|
||||||
auto sourceType = source().getType().cast<RankedTensorType>();
|
auto sourceType = getSource().getType().cast<RankedTensorType>();
|
||||||
auto resultType = result().getType().cast<RankedTensorType>();
|
auto resultType = getResult().getType().cast<RankedTensorType>();
|
||||||
auto expectedType =
|
auto expectedType = PadOp::inferResultType(
|
||||||
PadOp::inferResultType(sourceType, extractFromI64ArrayAttr(static_low()),
|
sourceType, extractFromI64ArrayAttr(getStaticLow()),
|
||||||
extractFromI64ArrayAttr(static_high()));
|
extractFromI64ArrayAttr(getStaticHigh()));
|
||||||
for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
|
for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
|
||||||
if (resultType.getDimSize(i) == expectedType.getDimSize(i))
|
if (resultType.getDimSize(i) == expectedType.getDimSize(i))
|
||||||
continue;
|
continue;
|
||||||
|
@ -1801,7 +1805,7 @@ LogicalResult PadOp::verify() {
|
||||||
|
|
||||||
LogicalResult PadOp::verifyRegions() {
|
LogicalResult PadOp::verifyRegions() {
|
||||||
auto ®ion = getRegion();
|
auto ®ion = getRegion();
|
||||||
unsigned rank = result().getType().cast<RankedTensorType>().getRank();
|
unsigned rank = getResult().getType().cast<RankedTensorType>().getRank();
|
||||||
Block &block = region.front();
|
Block &block = region.front();
|
||||||
if (block.getNumArguments() != rank)
|
if (block.getNumArguments() != rank)
|
||||||
return emitError("expected the block to have ") << rank << " arguments";
|
return emitError("expected the block to have ") << rank << " arguments";
|
||||||
|
@ -1815,7 +1819,7 @@ LogicalResult PadOp::verifyRegions() {
|
||||||
|
|
||||||
// Ensure that the region yields an element of the right type.
|
// Ensure that the region yields an element of the right type.
|
||||||
auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
|
auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
|
||||||
if (yieldOp.value().getType() !=
|
if (yieldOp.getValue().getType() !=
|
||||||
getType().cast<ShapedType>().getElementType())
|
getType().cast<ShapedType>().getElementType())
|
||||||
return emitOpError("expected yield type to match shape element type");
|
return emitOpError("expected yield type to match shape element type");
|
||||||
|
|
||||||
|
@ -1919,10 +1923,11 @@ struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
|
if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
|
||||||
return failure();
|
return failure();
|
||||||
if (padTensorOp.nofold())
|
if (padTensorOp.getNofold())
|
||||||
return failure();
|
return failure();
|
||||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(
|
||||||
padTensorOp, padTensorOp.result().getType(), padTensorOp.source());
|
padTensorOp, padTensorOp.getResult().getType(),
|
||||||
|
padTensorOp.getSource());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -1933,25 +1938,26 @@ struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(PadOp padTensorOp,
|
LogicalResult matchAndRewrite(PadOp padTensorOp,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
auto castOp = padTensorOp.source().getDefiningOp<tensor::CastOp>();
|
auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
|
||||||
if (!tensor::canFoldIntoConsumerOp(castOp))
|
if (!tensor::canFoldIntoConsumerOp(castOp))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto newResultType = PadOp::inferResultType(
|
auto newResultType = PadOp::inferResultType(
|
||||||
castOp.source().getType().cast<RankedTensorType>(),
|
castOp.getSource().getType().cast<RankedTensorType>(),
|
||||||
extractFromI64ArrayAttr(padTensorOp.static_low()),
|
extractFromI64ArrayAttr(padTensorOp.getStaticLow()),
|
||||||
extractFromI64ArrayAttr(padTensorOp.static_high()),
|
extractFromI64ArrayAttr(padTensorOp.getStaticHigh()),
|
||||||
padTensorOp.getResultType().getShape());
|
padTensorOp.getResultType().getShape());
|
||||||
|
|
||||||
if (newResultType == padTensorOp.getResultType()) {
|
if (newResultType == padTensorOp.getResultType()) {
|
||||||
rewriter.updateRootInPlace(padTensorOp, [&]() {
|
rewriter.updateRootInPlace(padTensorOp, [&]() {
|
||||||
padTensorOp.sourceMutable().assign(castOp.source());
|
padTensorOp.getSourceMutable().assign(castOp.getSource());
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
auto newOp = rewriter.create<PadOp>(
|
auto newOp = rewriter.create<PadOp>(
|
||||||
padTensorOp->getLoc(), newResultType, padTensorOp.source(),
|
padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
|
||||||
padTensorOp.low(), padTensorOp.high(), padTensorOp.static_low(),
|
padTensorOp.getLow(), padTensorOp.getHigh(),
|
||||||
padTensorOp.static_high(), padTensorOp.nofold());
|
padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
|
||||||
|
padTensorOp.getNofold());
|
||||||
BlockAndValueMapping mapper;
|
BlockAndValueMapping mapper;
|
||||||
padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
|
padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
|
||||||
|
|
||||||
|
@ -1969,25 +1975,25 @@ struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(PadOp padTensorOp,
|
LogicalResult matchAndRewrite(PadOp padTensorOp,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
if (!padTensorOp.result().hasOneUse())
|
if (!padTensorOp.getResult().hasOneUse())
|
||||||
return failure();
|
return failure();
|
||||||
auto tensorCastOp =
|
auto tensorCastOp =
|
||||||
dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
|
dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
|
||||||
if (!tensorCastOp)
|
if (!tensorCastOp)
|
||||||
return failure();
|
return failure();
|
||||||
if (!tensor::preservesStaticInformation(padTensorOp.result().getType(),
|
if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
|
||||||
tensorCastOp.dest().getType()))
|
tensorCastOp.getDest().getType()))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto replacementOp = rewriter.create<PadOp>(
|
auto replacementOp = rewriter.create<PadOp>(
|
||||||
padTensorOp.getLoc(), tensorCastOp.dest().getType(),
|
padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
|
||||||
padTensorOp.source(), padTensorOp.low(), padTensorOp.high(),
|
padTensorOp.getSource(), padTensorOp.getLow(), padTensorOp.getHigh(),
|
||||||
padTensorOp.static_low(), padTensorOp.static_high(),
|
padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
|
||||||
padTensorOp.nofold());
|
padTensorOp.getNofold());
|
||||||
replacementOp.region().takeBody(padTensorOp.region());
|
replacementOp.getRegion().takeBody(padTensorOp.getRegion());
|
||||||
|
|
||||||
rewriter.replaceOp(padTensorOp, replacementOp.result());
|
rewriter.replaceOp(padTensorOp, replacementOp.getResult());
|
||||||
rewriter.replaceOp(tensorCastOp, replacementOp.result());
|
rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -2031,13 +2037,13 @@ struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(PadOp padOp,
|
LogicalResult matchAndRewrite(PadOp padOp,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
auto innerSliceOp = padOp.source().getDefiningOp<ExtractSliceOp>();
|
auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
|
||||||
if (!innerSliceOp)
|
if (!innerSliceOp)
|
||||||
return failure();
|
return failure();
|
||||||
auto outerPadOp = innerSliceOp.source().getDefiningOp<PadOp>();
|
auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
|
||||||
if (!outerPadOp || outerPadOp.nofold())
|
if (!outerPadOp || outerPadOp.getNofold())
|
||||||
return failure();
|
return failure();
|
||||||
auto outerSliceOp = outerPadOp.source().getDefiningOp<ExtractSliceOp>();
|
auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
|
||||||
if (!outerSliceOp)
|
if (!outerSliceOp)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
@ -2136,11 +2142,11 @@ struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
|
||||||
// Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs the
|
// Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs the
|
||||||
// two paddings in one step.
|
// two paddings in one step.
|
||||||
auto newSliceOp = rewriter.create<ExtractSliceOp>(
|
auto newSliceOp = rewriter.create<ExtractSliceOp>(
|
||||||
padOp.getLoc(), outerSliceOp.source(), newOffsets, newSizes,
|
padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
|
||||||
innerSliceOp.getMixedStrides());
|
innerSliceOp.getMixedStrides());
|
||||||
auto newPadOp = rewriter.create<PadOp>(
|
auto newPadOp = rewriter.create<PadOp>(
|
||||||
padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
|
padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
|
||||||
padOp.getMixedLowPad(), newHighPad, padOp.nofold());
|
padOp.getMixedLowPad(), newHighPad, padOp.getNofold());
|
||||||
rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
|
rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
|
||||||
newPadOp.getRegion().begin());
|
newPadOp.getRegion().begin());
|
||||||
rewriter.replaceOp(padOp, newPadOp.getResult());
|
rewriter.replaceOp(padOp, newPadOp.getResult());
|
||||||
|
@ -2169,7 +2175,7 @@ Value PadOp::getConstantPaddingValue() {
|
||||||
auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
|
auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
|
||||||
if (!yieldOp)
|
if (!yieldOp)
|
||||||
return {};
|
return {};
|
||||||
Value padValue = yieldOp.value();
|
Value padValue = yieldOp.getValue();
|
||||||
// Check if yield value is a constant.
|
// Check if yield value is a constant.
|
||||||
if (matchPattern(padValue, m_Constant()))
|
if (matchPattern(padValue, m_Constant()))
|
||||||
return padValue;
|
return padValue;
|
||||||
|
@ -2182,8 +2188,8 @@ Value PadOp::getConstantPaddingValue() {
|
||||||
|
|
||||||
OpFoldResult PadOp::fold(ArrayRef<Attribute>) {
|
OpFoldResult PadOp::fold(ArrayRef<Attribute>) {
|
||||||
if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
|
if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
|
||||||
!nofold())
|
!getNofold())
|
||||||
return source();
|
return getSource();
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -75,7 +75,7 @@ struct SplitPadding final : public OpRewritePattern<tensor::PadOp> {
|
||||||
// Build the scf.if op itself. For the "then" branch, we can elide the
|
// Build the scf.if op itself. For the "then" branch, we can elide the
|
||||||
// padding. For the "else" branch, we retain the clone op.
|
// padding. For the "else" branch, we retain the clone op.
|
||||||
auto thenBuilder = [&padOp](OpBuilder &builder, Location loc) {
|
auto thenBuilder = [&padOp](OpBuilder &builder, Location loc) {
|
||||||
builder.create<scf::YieldOp>(loc, padOp.source());
|
builder.create<scf::YieldOp>(loc, padOp.getSource());
|
||||||
};
|
};
|
||||||
auto elseBuilder = [&padOp](OpBuilder &builder, Location loc) {
|
auto elseBuilder = [&padOp](OpBuilder &builder, Location loc) {
|
||||||
Operation *newOp = builder.clone(*padOp);
|
Operation *newOp = builder.clone(*padOp);
|
||||||
|
|
Loading…
Reference in New Issue