forked from OSchip/llvm-project
[mlir][linalg] Cleanup LinalgOp usage in fusion on tensors (NFC).
Replace the uses of deprecated Structured Op Interface methods in FusionOnTensors.cpp. This patch is based on https://reviews.llvm.org/D103394. Differential Revision: https://reviews.llvm.org/D103471
This commit is contained in:
parent
1cea1189c2
commit
f84b908f89
|
@ -28,7 +28,7 @@ using namespace mlir::linalg;
|
|||
|
||||
/// Conditions for elementwise fusion of generic operations.
|
||||
static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
|
||||
unsigned consumerIdx) {
|
||||
OpOperand *consumerOpOperand) {
|
||||
// Producer and consumer must have tensor semantics.
|
||||
if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
|
||||
return false;
|
||||
|
@ -40,12 +40,12 @@ static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
|
|||
|
||||
// Only allow fusing the producer of an input operand for now.
|
||||
// TODO: allow fusing the producer of an output operand.
|
||||
if (consumerIdx >= consumer.getNumInputs())
|
||||
if (!consumer.isInputTensor(consumerOpOperand))
|
||||
return false;
|
||||
|
||||
// Get the consumer index map. The number of results of the consumer index
|
||||
// map must match the number of loops of the producer.
|
||||
AffineMap consumerIndexMap = consumer.getIndexingMap(consumerIdx);
|
||||
AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand);
|
||||
if (consumerIndexMap.getNumResults() != producer.getNumLoops())
|
||||
return false;
|
||||
|
||||
|
@ -55,7 +55,8 @@ static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
|
|||
|
||||
// Finally the index_map for the result must be invertible. For now just
|
||||
// verify it is a permutation.
|
||||
AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
|
||||
AffineMap producerResultIndexMap =
|
||||
producer.getTiedIndexingMap(producer.getOutputOperand(0));
|
||||
return producerResultIndexMap.isPermutation();
|
||||
}
|
||||
|
||||
|
@ -63,7 +64,7 @@ static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
|
|||
/// the `producer` to use in the fused operation given the indexing map of the
|
||||
/// result of the producer in the consumer.
|
||||
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
|
||||
OpOperand &producerOpOperand, AffineMap producerResultIndexMap,
|
||||
OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
|
||||
AffineMap fusedConsumerArgIndexMap) {
|
||||
// The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
|
||||
// from consumer loop -> consumer arg tensor index/producer result tensor
|
||||
|
@ -78,10 +79,9 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
|
|||
assert(invProducerResultIndexMap &&
|
||||
"expected producer result indexig map to be invertible");
|
||||
|
||||
LinalgOp producer = cast<LinalgOp>(producerOpOperand.getOwner());
|
||||
LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
|
||||
// argMap is a map from producer loop -> producer arg tensor index.
|
||||
AffineMap argMap =
|
||||
producer.getIndexingMap(producerOpOperand.getOperandNumber());
|
||||
AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand);
|
||||
|
||||
// Compose argMap with invProducerResultIndexMap to get a map from
|
||||
// producer result tensor index -> producer arg tensor index.
|
||||
|
@ -96,9 +96,11 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
|
|||
/// op must be empty.
|
||||
static void
|
||||
generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
|
||||
GenericOp producer, GenericOp consumer,
|
||||
AffineMap consumerToProducerLoopsMap,
|
||||
unsigned consumerIdx, unsigned nloops) {
|
||||
OpOperand *consumerOpOperand,
|
||||
unsigned nloops) {
|
||||
auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp());
|
||||
auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
|
||||
// Build the region of the fused op.
|
||||
Block &producerBlock = producer->getRegion(0).front();
|
||||
Block &consumerBlock = consumer->getRegion(0).front();
|
||||
|
@ -129,11 +131,11 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
|
|||
}
|
||||
}
|
||||
// TODO: allow fusing the producer of an output operand.
|
||||
assert(consumerIdx < consumer.getNumInputs() &&
|
||||
assert(consumer.isInputTensor(consumerOpOperand) &&
|
||||
"expected producer of input operand");
|
||||
// 3. Consumer input operands up to consumerIdx (exclusive).
|
||||
for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
|
||||
consumerIdx)) // input assumption.
|
||||
consumerOpOperand->getOperandNumber())) // input assumption.
|
||||
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
|
||||
|
||||
// Replacing consumerIdx requires getting the cloned, yielded, value from
|
||||
|
@ -147,7 +149,7 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
|
|||
// 4.b. Producer output operand/map that is fused needs to be mapped to the
|
||||
// producer bbArg if it is an "initTensor" (i.e. its value is actually read).
|
||||
assert(producer->getNumResults() == 1 && "expected single result producer");
|
||||
if (producer.isInitTensor(&producer.getOutputOpOperands()[0])) {
|
||||
if (producer.isInitTensor(producer.getOutputOperand(0))) {
|
||||
BlockArgument bbArg = producerBlock.getArguments()
|
||||
.drop_front(producer.getNumInputs())
|
||||
// TODO: bbArg index of
|
||||
|
@ -155,9 +157,10 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
|
|||
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
|
||||
}
|
||||
// 5. Remaining consumer's input operands (drop past index `consumerIdx`).
|
||||
for (BlockArgument bbArg : consumerBlock.getArguments()
|
||||
.take_front(consumer.getNumInputs())
|
||||
.drop_front(consumerIdx + 1))
|
||||
for (BlockArgument bbArg :
|
||||
consumerBlock.getArguments()
|
||||
.take_front(consumer.getNumInputs())
|
||||
.drop_front(consumerOpOperand->getOperandNumber() + 1))
|
||||
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
|
||||
// 6. All of consumer's output operands.
|
||||
for (BlockArgument bbArg :
|
||||
|
@ -191,7 +194,8 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
|
|||
assert(!producer->isAncestor(replacement.getDefiningOp()) &&
|
||||
"yielded value must have been mapped");
|
||||
}
|
||||
mapper.map(consumerBlock.getArgument(consumerIdx), replacement);
|
||||
mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()),
|
||||
replacement);
|
||||
// 10. Clone operations from the consumer to the fused op.
|
||||
for (auto &op : consumerBlock.getOperations())
|
||||
rewriter.clone(op, mapper);
|
||||
|
@ -202,17 +206,16 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
|
|||
}
|
||||
|
||||
static Optional<SmallVector<Value>>
|
||||
fuseElementwiseOpsImpl(GenericOp producer, OpOperand &consumerOpOperand,
|
||||
fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
|
||||
const ControlElementwiseOpsFusionFn &controlFn,
|
||||
PatternRewriter &rewriter) {
|
||||
auto consumer = cast<GenericOp>(consumerOpOperand.getOwner());
|
||||
unsigned consumerIdx = consumerOpOperand.getOperandNumber();
|
||||
if (!areElementwiseOpsFusable(producer, consumer, consumerIdx) ||
|
||||
!controlFn(producer->getResult(0), consumerOpOperand))
|
||||
auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
|
||||
if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) ||
|
||||
!controlFn(producer->getResult(0), *consumerOpOperand))
|
||||
return llvm::None;
|
||||
|
||||
// TODO: allow fusing the producer of an output operand.
|
||||
assert(consumerIdx < consumer.getNumInputs() &&
|
||||
assert(consumer.isInputTensor(consumerOpOperand) &&
|
||||
"expected producer of input operand");
|
||||
|
||||
// Compute the fused operands list and indexing maps.
|
||||
|
@ -224,62 +227,66 @@ fuseElementwiseOpsImpl(GenericOp producer, OpOperand &consumerOpOperand,
|
|||
consumer->getNumOperands());
|
||||
// In the following, numbering matches that of `generateFusedElementwiseOpRegion`.
|
||||
// 3. Consumer input operands/maps up to consumerIdx (exclusive).
|
||||
llvm::append_range(fusedOperands,
|
||||
consumer.getInputs().take_front(consumerIdx));
|
||||
llvm::append_range(
|
||||
fusedIndexMaps,
|
||||
ArrayRef<AffineMap>{consumer.getInputIndexingMaps()}.take_front(
|
||||
consumerIdx));
|
||||
SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands();
|
||||
SmallVector<OpOperand *>::iterator it =
|
||||
llvm::find(consumerInputs, consumerOpOperand);
|
||||
assert(it != consumerInputs.end() && "expected to find the consumer operand");
|
||||
for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
|
||||
fusedOperands.push_back(opOperand->get());
|
||||
fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
|
||||
}
|
||||
// 4. Splice in producer's input operands/maps.
|
||||
llvm::append_range(fusedOperands, producer.getInputs());
|
||||
assert(producer->getNumResults() == 1 && "expected single result producer");
|
||||
AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
|
||||
for (auto &inputOpOperand : producer.getInputOpOperands()) {
|
||||
AffineMap producerResultIndexMap =
|
||||
producer.getTiedIndexingMap(producer.getOutputOperand(0));
|
||||
for (OpOperand *opOperand : producer.getInputOperands()) {
|
||||
fusedOperands.push_back(opOperand->get());
|
||||
// Compute indexing maps for the producer args in the fused operation.
|
||||
AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
|
||||
inputOpOperand, producerResultIndexMap,
|
||||
consumer.getInputIndexingMap(consumerIdx));
|
||||
opOperand, producerResultIndexMap,
|
||||
consumer.getTiedIndexingMap(consumerOpOperand));
|
||||
fusedIndexMaps.push_back(map);
|
||||
}
|
||||
// 4.b. Producer output operand/map that is fused needs to be passed if it is
|
||||
// an "initTensor" (i.e. its value is actually read).
|
||||
assert(producer->getNumResults() == 1 && "expected single result producer");
|
||||
if (producer.isInitTensor(&producer.getOutputOpOperands()[0])) {
|
||||
llvm::append_range(fusedOperands, producer.getOutputs().take_front());
|
||||
if (producer.isInitTensor(producer.getOutputOperand(0))) {
|
||||
fusedOperands.push_back(producer.getOutputOperand(0)->get());
|
||||
// Compute indexing maps for the producer args in the fused operation.
|
||||
AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
|
||||
producer.getOutputOpOperands().front(), producerResultIndexMap,
|
||||
consumer.getOutputIndexingMap(0));
|
||||
producer.getOutputOperand(0), producerResultIndexMap,
|
||||
consumer.getTiedIndexingMap(consumerOpOperand));
|
||||
fusedIndexMaps.push_back(map);
|
||||
}
|
||||
// 5. Remaining consumer's input operands/maps (drop past index
|
||||
// `consumerIdx`).
|
||||
llvm::append_range(fusedOperands,
|
||||
consumer.getInputs().drop_front(consumerIdx + 1));
|
||||
llvm::append_range(
|
||||
fusedIndexMaps,
|
||||
ArrayRef<AffineMap>{consumer.getInputIndexingMaps()}.drop_front(
|
||||
consumerIdx + 1));
|
||||
for (OpOperand *opOperand :
|
||||
llvm::make_range(std::next(it), consumerInputs.end())) {
|
||||
fusedOperands.push_back(opOperand->get());
|
||||
fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
|
||||
}
|
||||
// 6. All of consumer's output operands (skip operands: added by the builder).
|
||||
// llvm::append_range(fusedOperands, consumer.getOutputs());
|
||||
llvm::append_range(fusedIndexMaps, consumer.getOutputIndexingMaps());
|
||||
for (OpOperand *opOperand : consumer.getOutputOperands())
|
||||
fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
|
||||
// 7. All of producer's output operands/maps except the one fused.
|
||||
// TODO: allow fusion of multi-result producers.
|
||||
assert(producer->getNumResults() == 1 && "expected single result producer");
|
||||
|
||||
// Generate the fused op.
|
||||
SmallVector<Value> consumerOutputs = consumer.getOutputOperands();
|
||||
auto fusedOp = rewriter.create<GenericOp>(
|
||||
consumer.getLoc(), consumer->getResultTypes(),
|
||||
/*inputs=*/fusedOperands,
|
||||
// TODO: handle outputs.
|
||||
consumer.getOutputs(), rewriter.getAffineMapArrayAttr(fusedIndexMaps),
|
||||
consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
|
||||
consumer.iterator_types(),
|
||||
/*doc=*/nullptr,
|
||||
/*library_call=*/nullptr);
|
||||
|
||||
// Construct an AffineMap from consumer loops to producer loops.
|
||||
// consumer loop -> tensor index
|
||||
AffineMap consumerResultIndexMap = consumer.getInputIndexingMap(consumerIdx);
|
||||
AffineMap consumerResultIndexMap =
|
||||
consumer.getTiedIndexingMap(consumerOpOperand);
|
||||
// tensor index -> producer loop
|
||||
AffineMap invProducerResultIndexMap =
|
||||
inversePermutation(producerResultIndexMap);
|
||||
|
@ -289,9 +296,9 @@ fuseElementwiseOpsImpl(GenericOp producer, OpOperand &consumerOpOperand,
|
|||
AffineMap consumerToProducerLoopsMap =
|
||||
invProducerResultIndexMap.compose(consumerResultIndexMap);
|
||||
|
||||
generateFusedElementwiseOpRegion(rewriter, fusedOp, producer, consumer,
|
||||
consumerToProducerLoopsMap, consumerIdx,
|
||||
consumer.getNumLoops());
|
||||
generateFusedElementwiseOpRegion(rewriter, fusedOp,
|
||||
consumerToProducerLoopsMap,
|
||||
consumerOpOperand, consumer.getNumLoops());
|
||||
return SmallVector<Value>(fusedOp->getResults());
|
||||
}
|
||||
|
||||
|
@ -449,7 +456,7 @@ static bool isUnitDimExpansionOnly(ArrayRef<int64_t> expandedShape,
|
|||
/// The added reshapes are again expanding patterns, so they will get fused
|
||||
/// with its producers if possible.
|
||||
static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
|
||||
unsigned fusedTensorIndex) {
|
||||
OpOperand *fusableOpOperand) {
|
||||
// Is fusable only if:
|
||||
// - All the indexing maps for operands and results are projected
|
||||
// permutations.
|
||||
|
@ -462,7 +469,7 @@ static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
|
|||
.getValue()
|
||||
.isProjectedPermutation();
|
||||
}) &&
|
||||
genericOp.getIndexingMap(fusedTensorIndex).getNumResults() > 0 &&
|
||||
genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 &&
|
||||
llvm::all_of(genericOp.iterator_types(), [](Attribute attr) {
|
||||
return attr.cast<StringAttr>().getValue() ==
|
||||
getParallelIteratorTypeName();
|
||||
|
@ -478,7 +485,7 @@ public:
|
|||
// of the expanded op given the `indexingMap` of the fused operand/result of
|
||||
// the generic op, the `reassocationMaps` of the reshape op and the shape of
|
||||
// the expanded op.
|
||||
LogicalResult compute(LinalgOp linalgOp, unsigned fusedTensorIndex,
|
||||
LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
|
||||
ArrayRef<AffineMap> reassociationMaps,
|
||||
ArrayRef<int64_t> expandedShape,
|
||||
PatternRewriter &rewriter);
|
||||
|
@ -503,13 +510,13 @@ private:
|
|||
} // namespace
|
||||
|
||||
LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
|
||||
unsigned fusedTensorIndex,
|
||||
OpOperand *fusableOpOperand,
|
||||
ArrayRef<AffineMap> reassociationMaps,
|
||||
ArrayRef<int64_t> expandedShape,
|
||||
PatternRewriter &rewriter) {
|
||||
if (reassociationMaps.empty())
|
||||
return failure();
|
||||
AffineMap fusedIndexMap = linalgOp.getIndexingMap(fusedTensorIndex);
|
||||
AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand);
|
||||
|
||||
Optional<SmallVector<int64_t, 4>> originalLoopRange =
|
||||
linalgOp.getStaticLoopRanges();
|
||||
|
@ -676,9 +683,9 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
|
|||
/// been satisfied.
|
||||
static Optional<SmallVector<Value>>
|
||||
fuseWithReshapeByExpansion(GenericOp genericOp, TensorReshapeOp reshapeOp,
|
||||
unsigned fusedTensorIndex,
|
||||
OpOperand *fusableOpOperand,
|
||||
PatternRewriter &rewriter) {
|
||||
assert(isFusableWithReshapeByDimExpansion(genericOp, fusedTensorIndex) &&
|
||||
assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
|
||||
"preconditions for fuse operation failed");
|
||||
// Check if reshape is expanding or collapsing.
|
||||
bool isExpanding =
|
||||
|
@ -687,7 +694,7 @@ fuseWithReshapeByExpansion(GenericOp genericOp, TensorReshapeOp reshapeOp,
|
|||
isExpanding ? reshapeOp.getResultType() : reshapeOp.getSrcType();
|
||||
|
||||
ExpansionInfo expansionInfo;
|
||||
if (failed(expansionInfo.compute(genericOp, fusedTensorIndex,
|
||||
if (failed(expansionInfo.compute(genericOp, fusableOpOperand,
|
||||
reshapeOp.getReassociationMaps(),
|
||||
expandedType.getShape(), rewriter)))
|
||||
return llvm::None;
|
||||
|
@ -701,39 +708,39 @@ fuseWithReshapeByExpansion(GenericOp genericOp, TensorReshapeOp reshapeOp,
|
|||
}));
|
||||
|
||||
SmallVector<Value> expandedOpOperands;
|
||||
for (auto operand : llvm::enumerate(genericOp.getInputs())) {
|
||||
if (operand.index() == fusedTensorIndex) {
|
||||
for (OpOperand *opOperand : genericOp.getInputOperands()) {
|
||||
if (opOperand == fusableOpOperand) {
|
||||
expandedOpOperands.push_back(reshapeOp.src());
|
||||
continue;
|
||||
}
|
||||
AffineMap indexingMap = genericOp.getInputIndexingMap(operand.index());
|
||||
AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
|
||||
RankedTensorType expandedOperandType =
|
||||
getExpandedType(operand.value().getType().cast<RankedTensorType>(),
|
||||
getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
|
||||
indexingMap, expansionInfo);
|
||||
if (expandedOperandType != operand.value().getType()) {
|
||||
if (expandedOperandType != opOperand->get().getType()) {
|
||||
// Reshape the operand to get the right type.
|
||||
SmallVector<ReassociationIndices> reassociation =
|
||||
getReassociationForExpansion(indexingMap, expansionInfo);
|
||||
expandedOpOperands.push_back(rewriter.create<TensorReshapeOp>(
|
||||
genericOp.getLoc(), expandedOperandType, operand.value(),
|
||||
genericOp.getLoc(), expandedOperandType, opOperand->get(),
|
||||
reassociation));
|
||||
continue;
|
||||
}
|
||||
expandedOpOperands.push_back(operand.value());
|
||||
expandedOpOperands.push_back(opOperand->get());
|
||||
}
|
||||
|
||||
Location loc = genericOp.getLoc();
|
||||
SmallVector<Value> outputs;
|
||||
for (auto result : llvm::enumerate(genericOp.getOutputs())) {
|
||||
AffineMap indexingMap = genericOp.getOutputIndexingMap(result.index());
|
||||
for (OpOperand *opOperand : genericOp.getOutputOperands()) {
|
||||
AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
|
||||
RankedTensorType expandedOutputType =
|
||||
getExpandedType(result.value().getType().cast<RankedTensorType>(),
|
||||
getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
|
||||
indexingMap, expansionInfo);
|
||||
if (expandedOutputType != result.value().getType()) {
|
||||
if (expandedOutputType != opOperand->get().getType()) {
|
||||
SmallVector<ReassociationIndices> reassociation =
|
||||
getReassociationForExpansion(indexingMap, expansionInfo);
|
||||
outputs.push_back(rewriter.create<TensorReshapeOp>(
|
||||
genericOp.getLoc(), expandedOutputType, result.value(),
|
||||
genericOp.getLoc(), expandedOutputType, opOperand->get(),
|
||||
reassociation));
|
||||
}
|
||||
}
|
||||
|
@ -757,17 +764,19 @@ fuseWithReshapeByExpansion(GenericOp genericOp, TensorReshapeOp reshapeOp,
|
|||
// Reshape the result values to their original shape if this is a collapsing
|
||||
// reshape folded into its consumer.
|
||||
SmallVector<Value> resultVals;
|
||||
for (auto result : llvm::enumerate(genericOp->getResults())) {
|
||||
if (!isExpanding &&
|
||||
resultTypes[result.index()] != result.value().getType()) {
|
||||
for (OpResult opResult : genericOp->getOpResults()) {
|
||||
int64_t resultNumber = opResult.getResultNumber();
|
||||
if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) {
|
||||
SmallVector<ReassociationIndices> reassociation =
|
||||
getReassociationForExpansion(
|
||||
genericOp.getOutputIndexingMap(result.index()), expansionInfo);
|
||||
genericOp.getTiedIndexingMap(
|
||||
genericOp.getOutputOperand(resultNumber)),
|
||||
expansionInfo);
|
||||
resultVals.push_back(rewriter.create<TensorReshapeOp>(
|
||||
genericOp.getLoc(), result.value().getType(),
|
||||
fusedOp->getResult(result.index()), reassociation));
|
||||
genericOp.getLoc(), opResult.getType(),
|
||||
fusedOp->getResult(resultNumber), reassociation));
|
||||
} else {
|
||||
resultVals.push_back(fusedOp->getResult(result.index()));
|
||||
resultVals.push_back(fusedOp->getResult(resultNumber));
|
||||
}
|
||||
}
|
||||
// Assuming a single result.
|
||||
|
@ -809,12 +818,13 @@ struct FoldProducerReshapeOpByLinearization
|
|||
PatternRewriter &rewriter) const override {
|
||||
if (!genericOp.hasTensorSemantics())
|
||||
return failure();
|
||||
for (auto operand : llvm::enumerate(genericOp.getInputs())) {
|
||||
SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
|
||||
for (auto en : llvm::enumerate(inputOperands)) {
|
||||
TensorReshapeOp reshapeOp =
|
||||
operand.value().getDefiningOp<TensorReshapeOp>();
|
||||
en.value()->get().getDefiningOp<TensorReshapeOp>();
|
||||
if (!reshapeOp ||
|
||||
!isTensorReshapeOpFoldableByLinearization(
|
||||
reshapeOp, genericOp.getInputIndexingMap(operand.index()),
|
||||
reshapeOp, genericOp.getTiedIndexingMap(en.value()),
|
||||
/*asProducer =*/true) ||
|
||||
(foldUnitDimReshapesOnly &&
|
||||
!isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(),
|
||||
|
@ -822,18 +832,17 @@ struct FoldProducerReshapeOpByLinearization
|
|||
continue;
|
||||
|
||||
// Compute the fused operands list,
|
||||
SmallVector<Value> fusedOperands(genericOp.getInputs());
|
||||
fusedOperands[operand.index()] = reshapeOp.src();
|
||||
fusedOperands.append(genericOp.getOutputs().begin(),
|
||||
genericOp.getOutputs().end());
|
||||
SmallVector<Value> fusedOperands = genericOp.getInputOperands();
|
||||
fusedOperands[en.index()] = reshapeOp.src();
|
||||
SmallVector<Value> outputOperands = genericOp.getOutputOperands();
|
||||
llvm::append_range(fusedOperands, outputOperands);
|
||||
|
||||
// Compute indexing_maps for the fused operation. The indexing_maps for
|
||||
// the operands of the consumers that aren't fused are the same.
|
||||
SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>(
|
||||
genericOp.indexing_maps().template getAsValueRange<AffineMapAttr>());
|
||||
SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();
|
||||
|
||||
// Accepted consumer maps are either identity or permutation.
|
||||
auto invMap = inversePermutation(fusedIndexMaps[operand.index()]);
|
||||
auto invMap = inversePermutation(fusedIndexMaps[en.index()]);
|
||||
|
||||
// Compute the indexing map to use for the result of the producer.
|
||||
AffineMap modifiedMap =
|
||||
|
@ -843,7 +852,7 @@ struct FoldProducerReshapeOpByLinearization
|
|||
if (!expr.isPureAffine())
|
||||
return failure();
|
||||
}
|
||||
fusedIndexMaps[operand.index()] = modifiedMap;
|
||||
fusedIndexMaps[en.index()] = modifiedMap;
|
||||
|
||||
// Further check that the resulting index maps can be fused and
|
||||
// inverted. Without this the resultant op is not legal.
|
||||
|
@ -917,35 +926,36 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
|
|||
return failure();
|
||||
// Only support identity output maps. It could be extended to permutations if
|
||||
// needed.
|
||||
if (llvm::any_of(genericOp.getOutputIndexingMaps(),
|
||||
[](AffineMap map) { return !map.isIdentity(); }))
|
||||
if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *opOperand) {
|
||||
return !genericOp.getTiedIndexingMap(opOperand).isIdentity();
|
||||
}))
|
||||
return failure();
|
||||
int64_t destRank = genericOp.getNumParallelLoops();
|
||||
SmallVector<Value, 4> newOperands =
|
||||
llvm::to_vector<4>(genericOp.getInputs());
|
||||
SmallVector<Value> newOperands = genericOp.getInputOperands();
|
||||
TensorReshapeOp reshapeFound;
|
||||
// 1. Look for tensor_reshape operands and save the dimensions
|
||||
// merged.
|
||||
for (auto operand : llvm::enumerate(genericOp.getInputs())) {
|
||||
SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
|
||||
for (auto en : llvm::enumerate(inputOperands)) {
|
||||
TensorReshapeOp reshapeOp =
|
||||
operand.value().template getDefiningOp<TensorReshapeOp>();
|
||||
en.value()->get().template getDefiningOp<TensorReshapeOp>();
|
||||
if (!reshapeOp || reshapeOp.getSrcType().getRank() >
|
||||
reshapeOp.getResultType().getRank()) {
|
||||
continue;
|
||||
}
|
||||
// TODO: We could support non-identity map as long as the merged
|
||||
// dimensions are still contiguous.
|
||||
if (!genericOp.getIndexingMaps()[operand.index()].isIdentity())
|
||||
if (!genericOp.getTiedIndexingMap(en.value()).isIdentity())
|
||||
continue;
|
||||
if (reshapeFound) {
|
||||
// Only support a second reshape op if it has the same reassociation maps.
|
||||
if (reshapeFound.getReassociationMaps() ==
|
||||
reshapeOp.getReassociationMaps())
|
||||
newOperands[operand.index()] = reshapeOp.src();
|
||||
newOperands[en.index()] = reshapeOp.src();
|
||||
continue;
|
||||
}
|
||||
reshapeFound = reshapeOp;
|
||||
newOperands[operand.index()] = reshapeOp.src();
|
||||
newOperands[en.index()] = reshapeOp.src();
|
||||
}
|
||||
if (!reshapeFound)
|
||||
return failure();
|
||||
|
@ -962,9 +972,9 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
|
|||
// 2. Verify that we can merge the dimensions in the linalg and that we
|
||||
// don't need to create new reshapes operands. Inserting new reshape
|
||||
// operands would defeat the purpose of the transformation.
|
||||
for (auto operand : llvm::enumerate(genericOp.getInputs())) {
|
||||
if (operand.value() == newOperands[operand.index()]) {
|
||||
AffineMap map = genericOp.getIndexingMaps()[operand.index()];
|
||||
for (auto en : llvm::enumerate(inputOperands)) {
|
||||
if (en.value()->get() == newOperands[en.index()]) {
|
||||
AffineMap map = genericOp.getTiedIndexingMap(en.value());
|
||||
for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
|
||||
if (reassociation[remap[map.getDimPosition(i)]].size() > 1)
|
||||
return failure();
|
||||
|
@ -1036,9 +1046,9 @@ public:
|
|||
|
||||
LogicalResult matchAndRewrite(GenericOp genericOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
for (auto operand : llvm::enumerate(genericOp.getInputs())) {
|
||||
for (OpOperand *opOperand : genericOp.getInputOperands()) {
|
||||
TensorReshapeOp reshapeOp =
|
||||
operand.value().getDefiningOp<TensorReshapeOp>();
|
||||
opOperand->get().getDefiningOp<TensorReshapeOp>();
|
||||
if (!reshapeOp)
|
||||
continue;
|
||||
// Fold only if
|
||||
|
@ -1046,15 +1056,12 @@ public:
|
|||
// - All constraints of fusing with reshape by expansion are met.
|
||||
if (reshapeOp.getSrcType().getRank() <
|
||||
reshapeOp.getResultType().getRank() ||
|
||||
!isFusableWithReshapeByDimExpansion(genericOp, operand.index()) ||
|
||||
(!controlFoldingReshapes(
|
||||
reshapeOp->getResult(0),
|
||||
genericOp.getInputOpOperands()[operand.index()])))
|
||||
!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
|
||||
(!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
|
||||
continue;
|
||||
|
||||
Optional<SmallVector<Value>> replacementValues =
|
||||
fuseWithReshapeByExpansion(genericOp, reshapeOp, operand.index(),
|
||||
rewriter);
|
||||
fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
|
||||
if (!replacementValues)
|
||||
return failure();
|
||||
rewriter.replaceOp(genericOp, replacementValues.getValue());
|
||||
|
@ -1080,7 +1087,8 @@ struct FoldConsumerReshapeOpByLinearization
|
|||
if (!producer || !producer.hasTensorSemantics() ||
|
||||
producer.getNumOutputs() != 1 ||
|
||||
!isTensorReshapeOpFoldableByLinearization(
|
||||
reshapeOp, producer.getOutputIndexingMap(0),
|
||||
reshapeOp,
|
||||
producer.getTiedIndexingMap(producer.getOutputOperand(0)),
|
||||
/*asProducer =*/false) ||
|
||||
(foldUnitDimReshapesOnly &&
|
||||
!isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
|
||||
|
@ -1088,10 +1096,10 @@ struct FoldConsumerReshapeOpByLinearization
|
|||
return failure();
|
||||
// The indexing_maps for the operands of the fused operation are same as
|
||||
// those for the operands of the producer.
|
||||
SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>(
|
||||
producer.indexing_maps().getAsValueRange<AffineMapAttr>());
|
||||
SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();
|
||||
|
||||
auto invMap = inversePermutation(producer.getOutputIndexingMap(0));
|
||||
auto invMap = inversePermutation(
|
||||
producer.getTiedIndexingMap(producer.getOutputOperand(0)));
|
||||
|
||||
// Compute the indexing map to use for the operand of the producer.
|
||||
AffineMap modifiedMap =
|
||||
|
@ -1113,11 +1121,13 @@ struct FoldConsumerReshapeOpByLinearization
|
|||
}
|
||||
|
||||
Location loc = producer.getLoc();
|
||||
SmallVector<Value> inputOperands = producer.getInputOperands();
|
||||
Value output = rewriter.create<TensorReshapeOp>(
|
||||
loc, producer.getOutputs()[0], reshapeOp.getReassociationExprs());
|
||||
loc, producer.getOutputOperand(0)->get(),
|
||||
reshapeOp.getReassociationExprs());
|
||||
auto fusedOp = rewriter.create<GenericOp>(
|
||||
loc, reshapeOp.getResultType(),
|
||||
/*inputs=*/producer.getInputs(),
|
||||
/*inputs=*/inputOperands,
|
||||
// TODO: handle outputs.
|
||||
/*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
|
||||
producer.iterator_types(),
|
||||
|
@ -1147,12 +1157,12 @@ struct FoldReshapeWithGenericOpByExpansion
|
|||
GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
|
||||
if (!producer || producer.getNumOutputs() != 1 ||
|
||||
!isFusableWithReshapeByDimExpansion(producer,
|
||||
producer.getNumInputs()) ||
|
||||
producer.getOutputOperand(0)) ||
|
||||
isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(),
|
||||
reshapeOp.getReassociationMaps()))
|
||||
return failure();
|
||||
Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
|
||||
producer, reshapeOp, producer.getNumInputs(), rewriter);
|
||||
producer, reshapeOp, producer.getOutputOperand(0), rewriter);
|
||||
if (!replacementValues)
|
||||
return failure();
|
||||
rewriter.replaceOp(reshapeOp, replacementValues.getValue());
|
||||
|
@ -1171,21 +1181,29 @@ public:
|
|||
PatternRewriter &rewriter) const override {
|
||||
if (!genericOp.hasTensorSemantics())
|
||||
return failure();
|
||||
for (auto operand : llvm::enumerate(genericOp.getInputOpOperands())) {
|
||||
Operation *def = operand.value().get().getDefiningOp();
|
||||
for (OpOperand *opOperand : genericOp.getInputOperands()) {
|
||||
Operation *def = opOperand->get().getDefiningOp();
|
||||
DenseElementsAttr constantAttr;
|
||||
if (!def ||
|
||||
!matchPattern(def, m_Constant<DenseElementsAttr>(&constantAttr)) ||
|
||||
!constantAttr.isSplat() ||
|
||||
!controlFn(def->getResult(0), operand.value()))
|
||||
!constantAttr.isSplat() || !controlFn(def->getResult(0), *opOperand))
|
||||
continue;
|
||||
|
||||
// The indexing_maps for the operands of the fused operation are same as
|
||||
// those for the operands of the genericOp without the indexing map at
|
||||
// operand.index()
|
||||
SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>(
|
||||
genericOp.indexing_maps().getAsValueRange<AffineMapAttr>());
|
||||
fusedIndexMaps.erase(std::next(fusedIndexMaps.begin(), operand.index()));
|
||||
// The operands and the indexing_maps of the fused operation are the same as
|
||||
// the operands and indexing_maps of the generic operations with the
|
||||
// values at the constant index dropped.
|
||||
SmallVector<AffineMap> fusedIndexMaps;
|
||||
SmallVector<Value> fusedOperands;
|
||||
fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs());
|
||||
fusedOperands.reserve(genericOp.getNumInputs());
|
||||
for (OpOperand *inputOperand : genericOp.getInputOperands()) {
|
||||
if (inputOperand == opOperand)
|
||||
continue;
|
||||
fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand));
|
||||
fusedOperands.push_back(inputOperand->get());
|
||||
}
|
||||
for (OpOperand *outputOperand : genericOp.getOutputOperands())
|
||||
fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand));
|
||||
|
||||
// Check if the operation shapes to loops map is computable.
|
||||
if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
|
||||
|
@ -1193,20 +1211,16 @@ public:
|
|||
genericOp, "fused op loop bound computation failed");
|
||||
}
|
||||
|
||||
// The operands list is same as the genericOp with the argument for
|
||||
// constant index dropped.
|
||||
SmallVector<Value> fusedOperands(genericOp.getInputs());
|
||||
fusedOperands.erase(std::next(fusedOperands.begin(), operand.index()));
|
||||
|
||||
// Create a constant scalar value from the splat constant.
|
||||
Value scalarConstant = rewriter.create<ConstantOp>(
|
||||
def->getLoc(), constantAttr.getSplatValue(),
|
||||
constantAttr.getType().getElementType());
|
||||
|
||||
SmallVector<Value> outputOperands = genericOp.getOutputOperands();
|
||||
auto fusedOp = rewriter.create<GenericOp>(
|
||||
rewriter.getUnknownLoc(), genericOp->getResultTypes(),
|
||||
/*inputs=*/fusedOperands,
|
||||
/*outputs=*/genericOp.getOutputs(),
|
||||
/*outputs=*/outputOperands,
|
||||
rewriter.getAffineMapArrayAttr(fusedIndexMaps),
|
||||
genericOp.iterator_types(),
|
||||
/*doc=*/nullptr,
|
||||
|
@ -1217,7 +1231,8 @@ public:
|
|||
Region &region = genericOp->getRegion(0);
|
||||
Block &entryBlock = *region.begin();
|
||||
BlockAndValueMapping mapping;
|
||||
mapping.map(entryBlock.getArgument(operand.index()), scalarConstant);
|
||||
mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
|
||||
scalarConstant);
|
||||
Region &fusedRegion = fusedOp->getRegion(0);
|
||||
rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
|
||||
mapping);
|
||||
|
@ -1233,7 +1248,7 @@ private:
|
|||
} // namespace
|
||||
|
||||
static Optional<SmallVector<Value>>
|
||||
fuseElementwiseOps(PatternRewriter &rewriter, OpOperand &consumerOpOperand,
|
||||
fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
|
||||
GenericOp producer,
|
||||
const ControlElementwiseOpsFusionFn &controlFn) {
|
||||
if (producer->getNumResults() != 1)
|
||||
|
@ -1261,9 +1276,9 @@ public:
|
|||
LogicalResult matchAndRewrite(GenericOp genericOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Find the first operand that is defined by another generic op on tensors.
|
||||
for (OpOperand &opOperand : genericOp.getShapedOpOperands()) {
|
||||
for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
|
||||
auto producer =
|
||||
dyn_cast_or_null<GenericOp>(opOperand.get().getDefiningOp());
|
||||
dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
|
||||
if (!producer || !producer.hasTensorSemantics())
|
||||
continue;
|
||||
Optional<SmallVector<Value>> fusedOpResults =
|
||||
|
@ -1322,9 +1337,9 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
|
|||
rewriter.startRootUpdate(op);
|
||||
bool modifiedOutput = false;
|
||||
Location loc = op.getLoc();
|
||||
for (OpOperand &opOperand : op.getOutputOpOperands()) {
|
||||
if (!op.payloadUsesValueFromOpOperand(&opOperand)) {
|
||||
Value operandVal = opOperand.get();
|
||||
for (OpOperand *opOperand : op.getOutputOperands()) {
|
||||
if (!op.payloadUsesValueFromOperand(opOperand)) {
|
||||
Value operandVal = opOperand->get();
|
||||
auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
|
||||
if (!operandType)
|
||||
continue;
|
||||
|
@ -1344,7 +1359,7 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
|
|||
Value initTensor = rewriter.create<InitTensorOp>(
|
||||
loc, dynamicDims, operandType.getShape(),
|
||||
operandType.getElementType());
|
||||
op->setOperand(opOperand.getOperandNumber(), initTensor);
|
||||
op->setOperand(opOperand->getOperandNumber(), initTensor);
|
||||
}
|
||||
}
|
||||
if (!modifiedOutput) {
|
||||
|
|
Loading…
Reference in New Issue