[mlir][linalg] Cleanup LinalgOp usage in fusion on tensors (NFC).

Replace the uses of deprecated Structured Op Interface methods in FusionOnTensors.cpp. This patch is based on https://reviews.llvm.org/D103394.

Differential Revision: https://reviews.llvm.org/D103471
Tobias Gysi 2021-06-02 11:55:38 +00:00
parent 1cea1189c2
commit f84b908f89
1 changed file with 156 additions and 141 deletions
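The diff below applies one recurring substitution: positional Structured Op Interface accessors (getIndexingMap(idx), getInputIndexingMap(idx), getOutputIndexingMap(idx), getInputs()/getOutputs(), getInputOpOperands()) are replaced by OpOperand-based accessors (getTiedIndexingMap(opOperand), getInputOperands(), getOutputOperands(), getOutputOperand(idx), isInputTensor(), isInitTensor()). The sketch below is not part of the patch; the helper function, its name, and the header path are illustrative assumptions, and only the interface calls themselves are taken from the diff.

// Sketch only: contrasts the deprecated index-based accessors with the
// OpOperand-based ones used throughout this patch.
// Header path assumes the 2021 tree layout.
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"

using namespace mlir;
using namespace mlir::linalg;

// Hypothetical helper: collects the indexing map of every input operand plus
// the first output operand of a structured op.
static void collectIndexingMaps(LinalgOp linalgOp,
                                SmallVectorImpl<AffineMap> &maps) {
  // Deprecated, index-based style replaced by this patch:
  //   for (unsigned i = 0, e = linalgOp.getNumInputs(); i != e; ++i)
  //     maps.push_back(linalgOp.getIndexingMap(i));
  //   maps.push_back(linalgOp.getOutputIndexingMap(0));

  // OpOperand-based style used in the new code:
  for (OpOperand *opOperand : linalgOp.getInputOperands())
    maps.push_back(linalgOp.getTiedIndexingMap(opOperand));
  maps.push_back(linalgOp.getTiedIndexingMap(linalgOp.getOutputOperand(0)));
}

Addressing operands through OpOperand pointers keeps an operand, its value, and its indexing map tied together, so callers no longer translate between input-relative and absolute operand indices, which is what most of the hunks below simplify.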


@@ -28,7 +28,7 @@ using namespace mlir::linalg;
/// Conditions for elementwise fusion of generic operations.
static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
unsigned consumerIdx) {
OpOperand *consumerOpOperand) {
// Producer and consumer must have tensor semantics.
if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
return false;
@@ -40,12 +40,12 @@ static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
// Only allow fusing the producer of an input operand for now.
// TODO: allow fusing the producer of an output operand.
if (consumerIdx >= consumer.getNumInputs())
if (!consumer.isInputTensor(consumerOpOperand))
return false;
// Get the consumer index map. The number of results of the consumer index
// map must match the number of loops of the producer.
AffineMap consumerIndexMap = consumer.getIndexingMap(consumerIdx);
AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand);
if (consumerIndexMap.getNumResults() != producer.getNumLoops())
return false;
@@ -55,7 +55,8 @@ static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
// Finally the index_map for the result must be invertible. For now just
// verify it is a permutation.
AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
AffineMap producerResultIndexMap =
producer.getTiedIndexingMap(producer.getOutputOperand(0));
return producerResultIndexMap.isPermutation();
}
@@ -63,7 +64,7 @@ static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
/// the `producer` to use in the fused operation given the indexing map of the
/// result of the producer in the consumer.
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
OpOperand &producerOpOperand, AffineMap producerResultIndexMap,
OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
AffineMap fusedConsumerArgIndexMap) {
// The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
// from consumer loop -> consumer arg tensor index/producer result tensor
@@ -78,10 +79,9 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
assert(invProducerResultIndexMap &&
"expected producer result indexig map to be invertible");
LinalgOp producer = cast<LinalgOp>(producerOpOperand.getOwner());
LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
// argMap is a map from producer loop -> producer arg tensor index.
AffineMap argMap =
producer.getIndexingMap(producerOpOperand.getOperandNumber());
AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand);
// Compose argMap with invProducerResultIndexMap to get a map from
// producer result tensor index -> producer arg tensor index.
@@ -96,9 +96,11 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
/// op must be empty.
static void
generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
GenericOp producer, GenericOp consumer,
AffineMap consumerToProducerLoopsMap,
unsigned consumerIdx, unsigned nloops) {
OpOperand *consumerOpOperand,
unsigned nloops) {
auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp());
auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
// Build the region of the fused op.
Block &producerBlock = producer->getRegion(0).front();
Block &consumerBlock = consumer->getRegion(0).front();
@@ -129,11 +131,11 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
}
}
// TODO: allow fusing the producer of an output operand.
assert(consumerIdx < consumer.getNumInputs() &&
assert(consumer.isInputTensor(consumerOpOperand) &&
"expected producer of input operand");
// 3. Consumer input operands up to consumerIdx (exclusive).
for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
consumerIdx)) // input assumption.
consumerOpOperand->getOperandNumber())) // input assumption.
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
// Replacing consumerIdx requires getting the cloned, yielded, value from
@@ -147,7 +149,7 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
// 4.b. Producer output operand/map that is fused needs to be mapped to the
// producer bbArg if it is an "initTensor" (i.e. its value is actually read).
assert(producer->getNumResults() == 1 && "expected single result producer");
if (producer.isInitTensor(&producer.getOutputOpOperands()[0])) {
if (producer.isInitTensor(producer.getOutputOperand(0))) {
BlockArgument bbArg = producerBlock.getArguments()
.drop_front(producer.getNumInputs())
// TODO: bbArg index of
@@ -155,9 +157,10 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
}
// 5. Remaining consumer's input operands (drop past index `consumerIdx`).
for (BlockArgument bbArg : consumerBlock.getArguments()
.take_front(consumer.getNumInputs())
.drop_front(consumerIdx + 1))
for (BlockArgument bbArg :
consumerBlock.getArguments()
.take_front(consumer.getNumInputs())
.drop_front(consumerOpOperand->getOperandNumber() + 1))
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
// 6. All of consumer's output operands.
for (BlockArgument bbArg :
@@ -191,7 +194,8 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
assert(!producer->isAncestor(replacement.getDefiningOp()) &&
"yielded value must have been mapped");
}
mapper.map(consumerBlock.getArgument(consumerIdx), replacement);
mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()),
replacement);
// 10. Clone operations from the consumer to the fused op.
for (auto &op : consumerBlock.getOperations())
rewriter.clone(op, mapper);
@@ -202,17 +206,16 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
}
static Optional<SmallVector<Value>>
fuseElementwiseOpsImpl(GenericOp producer, OpOperand &consumerOpOperand,
fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
const ControlElementwiseOpsFusionFn &controlFn,
PatternRewriter &rewriter) {
auto consumer = cast<GenericOp>(consumerOpOperand.getOwner());
unsigned consumerIdx = consumerOpOperand.getOperandNumber();
if (!areElementwiseOpsFusable(producer, consumer, consumerIdx) ||
!controlFn(producer->getResult(0), consumerOpOperand))
auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) ||
!controlFn(producer->getResult(0), *consumerOpOperand))
return llvm::None;
// TODO: allow fusing the producer of an output operand.
assert(consumerIdx < consumer.getNumInputs() &&
assert(consumer.isInputTensor(consumerOpOperand) &&
"expected producer of input operand");
// Compute the fused operands list and indexing maps.
@@ -224,62 +227,66 @@ fuseElementwiseOpsImpl(GenericOp producer, OpOperand &consumerOpOperand,
consumer->getNumOperands());
// In the following, numbering matches that of `generateFusedElementwiseOpRegion`.
// 3. Consumer input operands/maps up to consumerIdx (exclusive).
llvm::append_range(fusedOperands,
consumer.getInputs().take_front(consumerIdx));
llvm::append_range(
fusedIndexMaps,
ArrayRef<AffineMap>{consumer.getInputIndexingMaps()}.take_front(
consumerIdx));
SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands();
SmallVector<OpOperand *>::iterator it =
llvm::find(consumerInputs, consumerOpOperand);
assert(it != consumerInputs.end() && "expected to find the consumer operand");
for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
fusedOperands.push_back(opOperand->get());
fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
}
// 4. Splice in producer's input operands/maps.
llvm::append_range(fusedOperands, producer.getInputs());
assert(producer->getNumResults() == 1 && "expected single result producer");
AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
for (auto &inputOpOperand : producer.getInputOpOperands()) {
AffineMap producerResultIndexMap =
producer.getTiedIndexingMap(producer.getOutputOperand(0));
for (OpOperand *opOperand : producer.getInputOperands()) {
fusedOperands.push_back(opOperand->get());
// Compute indexing maps for the producer args in the fused operation.
AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
inputOpOperand, producerResultIndexMap,
consumer.getInputIndexingMap(consumerIdx));
opOperand, producerResultIndexMap,
consumer.getTiedIndexingMap(consumerOpOperand));
fusedIndexMaps.push_back(map);
}
// 4.b. Producer output operand/map that is fused needs to be passed if it is
// an "initTensor" (i.e. its value is actually read).
assert(producer->getNumResults() == 1 && "expected single result producer");
if (producer.isInitTensor(&producer.getOutputOpOperands()[0])) {
llvm::append_range(fusedOperands, producer.getOutputs().take_front());
if (producer.isInitTensor(producer.getOutputOperand(0))) {
fusedOperands.push_back(producer.getOutputOperand(0)->get());
// Compute indexing maps for the producer args in the fused operation.
AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
producer.getOutputOpOperands().front(), producerResultIndexMap,
consumer.getOutputIndexingMap(0));
producer.getOutputOperand(0), producerResultIndexMap,
consumer.getTiedIndexingMap(consumerOpOperand));
fusedIndexMaps.push_back(map);
}
// 5. Remaining consumer's input operands/maps (drop past index
// `consumerIdx`).
llvm::append_range(fusedOperands,
consumer.getInputs().drop_front(consumerIdx + 1));
llvm::append_range(
fusedIndexMaps,
ArrayRef<AffineMap>{consumer.getInputIndexingMaps()}.drop_front(
consumerIdx + 1));
for (OpOperand *opOperand :
llvm::make_range(std::next(it), consumerInputs.end())) {
fusedOperands.push_back(opOperand->get());
fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
}
// 6. All of consumer's output operands (skip operands: added by the builder).
// llvm::append_range(fusedOperands, consumer.getOutputs());
llvm::append_range(fusedIndexMaps, consumer.getOutputIndexingMaps());
for (OpOperand *opOperand : consumer.getOutputOperands())
fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
// 7. All of producer's output operands/maps except the one fused.
// TODO: allow fusion of multi-result producers.
assert(producer->getNumResults() == 1 && "expected single result producer");
// Generate the fused op.
SmallVector<Value> consumerOutputs = consumer.getOutputOperands();
auto fusedOp = rewriter.create<GenericOp>(
consumer.getLoc(), consumer->getResultTypes(),
/*inputs=*/fusedOperands,
// TODO: handle outputs.
consumer.getOutputs(), rewriter.getAffineMapArrayAttr(fusedIndexMaps),
consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
consumer.iterator_types(),
/*doc=*/nullptr,
/*library_call=*/nullptr);
// Construct an AffineMap from consumer loops to producer loops.
// consumer loop -> tensor index
AffineMap consumerResultIndexMap = consumer.getInputIndexingMap(consumerIdx);
AffineMap consumerResultIndexMap =
consumer.getTiedIndexingMap(consumerOpOperand);
// tensor index -> producer loop
AffineMap invProducerResultIndexMap =
inversePermutation(producerResultIndexMap);
@@ -289,9 +296,9 @@ fuseElementwiseOpsImpl(GenericOp producer, OpOperand &consumerOpOperand,
AffineMap consumerToProducerLoopsMap =
invProducerResultIndexMap.compose(consumerResultIndexMap);
generateFusedElementwiseOpRegion(rewriter, fusedOp, producer, consumer,
consumerToProducerLoopsMap, consumerIdx,
consumer.getNumLoops());
generateFusedElementwiseOpRegion(rewriter, fusedOp,
consumerToProducerLoopsMap,
consumerOpOperand, consumer.getNumLoops());
return SmallVector<Value>(fusedOp->getResults());
}
@@ -449,7 +456,7 @@ static bool isUnitDimExpansionOnly(ArrayRef<int64_t> expandedShape,
/// The added reshapes are again expanding patterns, so they will get fused
/// with its producers if possible.
static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
unsigned fusedTensorIndex) {
OpOperand *fusableOpOperand) {
// Is fusable only if:
// - All the indexing maps for operands and results are projected
// permutations.
@@ -462,7 +469,7 @@ static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
.getValue()
.isProjectedPermutation();
}) &&
genericOp.getIndexingMap(fusedTensorIndex).getNumResults() > 0 &&
genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 &&
llvm::all_of(genericOp.iterator_types(), [](Attribute attr) {
return attr.cast<StringAttr>().getValue() ==
getParallelIteratorTypeName();
@@ -478,7 +485,7 @@ public:
// of the expanded op given the `indexingMap` of the fused operand/result of
// the generic op, the `reassociationMaps` of the reshape op and the shape of
// the expanded op.
LogicalResult compute(LinalgOp linalgOp, unsigned fusedTensorIndex,
LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
ArrayRef<AffineMap> reassociationMaps,
ArrayRef<int64_t> expandedShape,
PatternRewriter &rewriter);
@@ -503,13 +510,13 @@ private:
} // namespace
LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
unsigned fusedTensorIndex,
OpOperand *fusableOpOperand,
ArrayRef<AffineMap> reassociationMaps,
ArrayRef<int64_t> expandedShape,
PatternRewriter &rewriter) {
if (reassociationMaps.empty())
return failure();
AffineMap fusedIndexMap = linalgOp.getIndexingMap(fusedTensorIndex);
AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand);
Optional<SmallVector<int64_t, 4>> originalLoopRange =
linalgOp.getStaticLoopRanges();
@@ -676,9 +683,9 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
/// been satisfied.
static Optional<SmallVector<Value>>
fuseWithReshapeByExpansion(GenericOp genericOp, TensorReshapeOp reshapeOp,
unsigned fusedTensorIndex,
OpOperand *fusableOpOperand,
PatternRewriter &rewriter) {
assert(isFusableWithReshapeByDimExpansion(genericOp, fusedTensorIndex) &&
assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
"preconditions for fuse operation failed");
// Check if reshape is expanding or collapsing.
bool isExpanding =
@@ -687,7 +694,7 @@ fuseWithReshapeByExpansion(GenericOp genericOp, TensorReshapeOp reshapeOp,
isExpanding ? reshapeOp.getResultType() : reshapeOp.getSrcType();
ExpansionInfo expansionInfo;
if (failed(expansionInfo.compute(genericOp, fusedTensorIndex,
if (failed(expansionInfo.compute(genericOp, fusableOpOperand,
reshapeOp.getReassociationMaps(),
expandedType.getShape(), rewriter)))
return llvm::None;
@@ -701,39 +708,39 @@ fuseWithReshapeByExpansion(GenericOp genericOp, TensorReshapeOp reshapeOp,
}));
SmallVector<Value> expandedOpOperands;
for (auto operand : llvm::enumerate(genericOp.getInputs())) {
if (operand.index() == fusedTensorIndex) {
for (OpOperand *opOperand : genericOp.getInputOperands()) {
if (opOperand == fusableOpOperand) {
expandedOpOperands.push_back(reshapeOp.src());
continue;
}
AffineMap indexingMap = genericOp.getInputIndexingMap(operand.index());
AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
RankedTensorType expandedOperandType =
getExpandedType(operand.value().getType().cast<RankedTensorType>(),
getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
indexingMap, expansionInfo);
if (expandedOperandType != operand.value().getType()) {
if (expandedOperandType != opOperand->get().getType()) {
// Reshape the operand to get the right type.
SmallVector<ReassociationIndices> reassociation =
getReassociationForExpansion(indexingMap, expansionInfo);
expandedOpOperands.push_back(rewriter.create<TensorReshapeOp>(
genericOp.getLoc(), expandedOperandType, operand.value(),
genericOp.getLoc(), expandedOperandType, opOperand->get(),
reassociation));
continue;
}
expandedOpOperands.push_back(operand.value());
expandedOpOperands.push_back(opOperand->get());
}
Location loc = genericOp.getLoc();
SmallVector<Value> outputs;
for (auto result : llvm::enumerate(genericOp.getOutputs())) {
AffineMap indexingMap = genericOp.getOutputIndexingMap(result.index());
for (OpOperand *opOperand : genericOp.getOutputOperands()) {
AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
RankedTensorType expandedOutputType =
getExpandedType(result.value().getType().cast<RankedTensorType>(),
getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
indexingMap, expansionInfo);
if (expandedOutputType != result.value().getType()) {
if (expandedOutputType != opOperand->get().getType()) {
SmallVector<ReassociationIndices> reassociation =
getReassociationForExpansion(indexingMap, expansionInfo);
outputs.push_back(rewriter.create<TensorReshapeOp>(
genericOp.getLoc(), expandedOutputType, result.value(),
genericOp.getLoc(), expandedOutputType, opOperand->get(),
reassociation));
}
}
@@ -757,17 +764,19 @@ fuseWithReshapeByExpansion(GenericOp genericOp, TensorReshapeOp reshapeOp,
// Reshape the result values to their original shape if this is a collapsing
// reshape folded into its consumer.
SmallVector<Value> resultVals;
for (auto result : llvm::enumerate(genericOp->getResults())) {
if (!isExpanding &&
resultTypes[result.index()] != result.value().getType()) {
for (OpResult opResult : genericOp->getOpResults()) {
int64_t resultNumber = opResult.getResultNumber();
if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) {
SmallVector<ReassociationIndices> reassociation =
getReassociationForExpansion(
genericOp.getOutputIndexingMap(result.index()), expansionInfo);
genericOp.getTiedIndexingMap(
genericOp.getOutputOperand(resultNumber)),
expansionInfo);
resultVals.push_back(rewriter.create<TensorReshapeOp>(
genericOp.getLoc(), result.value().getType(),
fusedOp->getResult(result.index()), reassociation));
genericOp.getLoc(), opResult.getType(),
fusedOp->getResult(resultNumber), reassociation));
} else {
resultVals.push_back(fusedOp->getResult(result.index()));
resultVals.push_back(fusedOp->getResult(resultNumber));
}
}
// Assuming a single result.
@@ -809,12 +818,13 @@ struct FoldProducerReshapeOpByLinearization
PatternRewriter &rewriter) const override {
if (!genericOp.hasTensorSemantics())
return failure();
for (auto operand : llvm::enumerate(genericOp.getInputs())) {
SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
for (auto en : llvm::enumerate(inputOperands)) {
TensorReshapeOp reshapeOp =
operand.value().getDefiningOp<TensorReshapeOp>();
en.value()->get().getDefiningOp<TensorReshapeOp>();
if (!reshapeOp ||
!isTensorReshapeOpFoldableByLinearization(
reshapeOp, genericOp.getInputIndexingMap(operand.index()),
reshapeOp, genericOp.getTiedIndexingMap(en.value()),
/*asProducer =*/true) ||
(foldUnitDimReshapesOnly &&
!isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(),
@@ -822,18 +832,17 @@ struct FoldProducerReshapeOpByLinearization
continue;
// Compute the fused operands list,
SmallVector<Value> fusedOperands(genericOp.getInputs());
fusedOperands[operand.index()] = reshapeOp.src();
fusedOperands.append(genericOp.getOutputs().begin(),
genericOp.getOutputs().end());
SmallVector<Value> fusedOperands = genericOp.getInputOperands();
fusedOperands[en.index()] = reshapeOp.src();
SmallVector<Value> outputOperands = genericOp.getOutputOperands();
llvm::append_range(fusedOperands, outputOperands);
// Compute indexing_maps for the fused operation. The indexing_maps for
// the operands of the consumers that aren't fused are the same.
SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>(
genericOp.indexing_maps().template getAsValueRange<AffineMapAttr>());
SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();
// Accepted consumer maps are either identity or permutation.
auto invMap = inversePermutation(fusedIndexMaps[operand.index()]);
auto invMap = inversePermutation(fusedIndexMaps[en.index()]);
// Compute the indexing map to use for the result of the producer.
AffineMap modifiedMap =
@@ -843,7 +852,7 @@ struct FoldProducerReshapeOpByLinearization
if (!expr.isPureAffine())
return failure();
}
fusedIndexMaps[operand.index()] = modifiedMap;
fusedIndexMaps[en.index()] = modifiedMap;
// Further check that the resulting index maps can be fused and
// inverted. Without this the resultant op is not legal.
@@ -917,35 +926,36 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
return failure();
// Only support identity output maps. It could be extended to permutations if
// needed.
if (llvm::any_of(genericOp.getOutputIndexingMaps(),
[](AffineMap map) { return !map.isIdentity(); }))
if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *opOperand) {
return !genericOp.getTiedIndexingMap(opOperand).isIdentity();
}))
return failure();
int64_t destRank = genericOp.getNumParallelLoops();
SmallVector<Value, 4> newOperands =
llvm::to_vector<4>(genericOp.getInputs());
SmallVector<Value> newOperands = genericOp.getInputOperands();
TensorReshapeOp reshapeFound;
// 1. Look for tensor_reshape operands and record the dimensions that get
// merged.
for (auto operand : llvm::enumerate(genericOp.getInputs())) {
SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
for (auto en : llvm::enumerate(inputOperands)) {
TensorReshapeOp reshapeOp =
operand.value().template getDefiningOp<TensorReshapeOp>();
en.value()->get().template getDefiningOp<TensorReshapeOp>();
if (!reshapeOp || reshapeOp.getSrcType().getRank() >
reshapeOp.getResultType().getRank()) {
continue;
}
// TODO: We could support non-identity map as long as the merged
// dimensions are still contiguous.
if (!genericOp.getIndexingMaps()[operand.index()].isIdentity())
if (!genericOp.getTiedIndexingMap(en.value()).isIdentity())
continue;
if (reshapeFound) {
// Only support a second reshape op if it has the same reassociate maps.
if (reshapeFound.getReassociationMaps() ==
reshapeOp.getReassociationMaps())
newOperands[operand.index()] = reshapeOp.src();
newOperands[en.index()] = reshapeOp.src();
continue;
}
reshapeFound = reshapeOp;
newOperands[operand.index()] = reshapeOp.src();
newOperands[en.index()] = reshapeOp.src();
}
if (!reshapeFound)
return failure();
@@ -962,9 +972,9 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
// 2. Verify that we can merge the dimensions in the linalg and that we
// don't need to create new reshape operands. Inserting new reshape
// operands would defeat the purpose of the transformation.
for (auto operand : llvm::enumerate(genericOp.getInputs())) {
if (operand.value() == newOperands[operand.index()]) {
AffineMap map = genericOp.getIndexingMaps()[operand.index()];
for (auto en : llvm::enumerate(inputOperands)) {
if (en.value()->get() == newOperands[en.index()]) {
AffineMap map = genericOp.getTiedIndexingMap(en.value());
for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
if (reassociation[remap[map.getDimPosition(i)]].size() > 1)
return failure();
@@ -1036,9 +1046,9 @@ public:
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
for (auto operand : llvm::enumerate(genericOp.getInputs())) {
for (OpOperand *opOperand : genericOp.getInputOperands()) {
TensorReshapeOp reshapeOp =
operand.value().getDefiningOp<TensorReshapeOp>();
opOperand->get().getDefiningOp<TensorReshapeOp>();
if (!reshapeOp)
continue;
// Fold only if
@@ -1046,15 +1056,12 @@ public:
// - All constraints of fusing with reshape by expansion are met.
if (reshapeOp.getSrcType().getRank() <
reshapeOp.getResultType().getRank() ||
!isFusableWithReshapeByDimExpansion(genericOp, operand.index()) ||
(!controlFoldingReshapes(
reshapeOp->getResult(0),
genericOp.getInputOpOperands()[operand.index()])))
!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
(!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
continue;
Optional<SmallVector<Value>> replacementValues =
fuseWithReshapeByExpansion(genericOp, reshapeOp, operand.index(),
rewriter);
fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
if (!replacementValues)
return failure();
rewriter.replaceOp(genericOp, replacementValues.getValue());
@@ -1080,7 +1087,8 @@ struct FoldConsumerReshapeOpByLinearization
if (!producer || !producer.hasTensorSemantics() ||
producer.getNumOutputs() != 1 ||
!isTensorReshapeOpFoldableByLinearization(
reshapeOp, producer.getOutputIndexingMap(0),
reshapeOp,
producer.getTiedIndexingMap(producer.getOutputOperand(0)),
/*asProducer =*/false) ||
(foldUnitDimReshapesOnly &&
!isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
@@ -1088,10 +1096,10 @@ struct FoldConsumerReshapeOpByLinearization
return failure();
// The indexing_maps for the operands of the fused operation are same as
// those for the operands of the producer.
SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>(
producer.indexing_maps().getAsValueRange<AffineMapAttr>());
SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();
auto invMap = inversePermutation(producer.getOutputIndexingMap(0));
auto invMap = inversePermutation(
producer.getTiedIndexingMap(producer.getOutputOperand(0)));
// Compute the indexing map to use for the operand of the producer.
AffineMap modifiedMap =
@@ -1113,11 +1121,13 @@ struct FoldConsumerReshapeOpByLinearization
}
Location loc = producer.getLoc();
SmallVector<Value> inputOperands = producer.getInputOperands();
Value output = rewriter.create<TensorReshapeOp>(
loc, producer.getOutputs()[0], reshapeOp.getReassociationExprs());
loc, producer.getOutputOperand(0)->get(),
reshapeOp.getReassociationExprs());
auto fusedOp = rewriter.create<GenericOp>(
loc, reshapeOp.getResultType(),
/*inputs=*/producer.getInputs(),
/*inputs=*/inputOperands,
// TODO: handle outputs.
/*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
producer.iterator_types(),
@@ -1147,12 +1157,12 @@ struct FoldReshapeWithGenericOpByExpansion
GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
if (!producer || producer.getNumOutputs() != 1 ||
!isFusableWithReshapeByDimExpansion(producer,
producer.getNumInputs()) ||
producer.getOutputOperand(0)) ||
isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(),
reshapeOp.getReassociationMaps()))
return failure();
Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
producer, reshapeOp, producer.getNumInputs(), rewriter);
producer, reshapeOp, producer.getOutputOperand(0), rewriter);
if (!replacementValues)
return failure();
rewriter.replaceOp(reshapeOp, replacementValues.getValue());
@@ -1171,21 +1181,29 @@ public:
PatternRewriter &rewriter) const override {
if (!genericOp.hasTensorSemantics())
return failure();
for (auto operand : llvm::enumerate(genericOp.getInputOpOperands())) {
Operation *def = operand.value().get().getDefiningOp();
for (OpOperand *opOperand : genericOp.getInputOperands()) {
Operation *def = opOperand->get().getDefiningOp();
DenseElementsAttr constantAttr;
if (!def ||
!matchPattern(def, m_Constant<DenseElementsAttr>(&constantAttr)) ||
!constantAttr.isSplat() ||
!controlFn(def->getResult(0), operand.value()))
!constantAttr.isSplat() || !controlFn(def->getResult(0), *opOperand))
continue;
// The indexing_maps for the operands of the fused operation are same as
// those for the operands of the genericOp without the indexing map at
// operand.index()
SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>(
genericOp.indexing_maps().getAsValueRange<AffineMapAttr>());
fusedIndexMaps.erase(std::next(fusedIndexMaps.begin(), operand.index()));
// The operands and the indexing_maps of the fused operation are the same
// as the operands and indexing_maps of the generic operations with the
// values at the constant index dropped.
SmallVector<AffineMap> fusedIndexMaps;
SmallVector<Value> fusedOperands;
fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs());
fusedOperands.reserve(genericOp.getNumInputs());
for (OpOperand *inputOperand : genericOp.getInputOperands()) {
if (inputOperand == opOperand)
continue;
fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand));
fusedOperands.push_back(inputOperand->get());
}
for (OpOperand *outputOperand : genericOp.getOutputOperands())
fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand));
// Check if the operation shapes to loops map is computable.
if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
@@ -1193,20 +1211,16 @@ public:
genericOp, "fused op loop bound computation failed");
}
// The operands list is same as the genericOp with the argument for
// constant index dropped.
SmallVector<Value> fusedOperands(genericOp.getInputs());
fusedOperands.erase(std::next(fusedOperands.begin(), operand.index()));
// Create a constant scalar value from the splat constant.
Value scalarConstant = rewriter.create<ConstantOp>(
def->getLoc(), constantAttr.getSplatValue(),
constantAttr.getType().getElementType());
SmallVector<Value> outputOperands = genericOp.getOutputOperands();
auto fusedOp = rewriter.create<GenericOp>(
rewriter.getUnknownLoc(), genericOp->getResultTypes(),
/*inputs=*/fusedOperands,
/*outputs=*/genericOp.getOutputs(),
/*outputs=*/outputOperands,
rewriter.getAffineMapArrayAttr(fusedIndexMaps),
genericOp.iterator_types(),
/*doc=*/nullptr,
@@ -1217,7 +1231,8 @@ public:
Region &region = genericOp->getRegion(0);
Block &entryBlock = *region.begin();
BlockAndValueMapping mapping;
mapping.map(entryBlock.getArgument(operand.index()), scalarConstant);
mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
scalarConstant);
Region &fusedRegion = fusedOp->getRegion(0);
rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
mapping);
@@ -1233,7 +1248,7 @@ private:
} // namespace
static Optional<SmallVector<Value>>
fuseElementwiseOps(PatternRewriter &rewriter, OpOperand &consumerOpOperand,
fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
GenericOp producer,
const ControlElementwiseOpsFusionFn &controlFn) {
if (producer->getNumResults() != 1)
@@ -1261,9 +1276,9 @@ public:
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
// Find the first operand that is defined by another generic op on tensors.
for (OpOperand &opOperand : genericOp.getShapedOpOperands()) {
for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
auto producer =
dyn_cast_or_null<GenericOp>(opOperand.get().getDefiningOp());
dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
if (!producer || !producer.hasTensorSemantics())
continue;
Optional<SmallVector<Value>> fusedOpResults =
@@ -1322,9 +1337,9 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
rewriter.startRootUpdate(op);
bool modifiedOutput = false;
Location loc = op.getLoc();
for (OpOperand &opOperand : op.getOutputOpOperands()) {
if (!op.payloadUsesValueFromOpOperand(&opOperand)) {
Value operandVal = opOperand.get();
for (OpOperand *opOperand : op.getOutputOperands()) {
if (!op.payloadUsesValueFromOperand(opOperand)) {
Value operandVal = opOperand->get();
auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
if (!operandType)
continue;
@@ -1344,7 +1359,7 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
Value initTensor = rewriter.create<InitTensorOp>(
loc, dynamicDims, operandType.getShape(),
operandType.getElementType());
op->setOperand(opOperand.getOperandNumber(), initTensor);
op->setOperand(opOperand->getOperandNumber(), initTensor);
}
}
if (!modifiedOutput) {