forked from OSchip/llvm-project
[mlir][Linalg] Fix fusion on tensors operands / bbArg mismatch
Linalg fusion on tensors has mismatching assumptions on the operand side than on the region bbArg side. Relax the behavior on the operand/indexing map side so that we better support output operands that may also be read from. Differential revision: https://reviews.llvm.org/D99499
This commit is contained in:
parent
916093f49f
commit
518e6f341d
|
@ -896,7 +896,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
|
|||
/*desc=*/[{
|
||||
Return the indexing maps within the current operation.
|
||||
}],
|
||||
/*retTy=*/"SmallVector<AffineMap, 4>",
|
||||
/*retTy=*/"SmallVector<AffineMap>",
|
||||
/*methodName=*/"getIndexingMaps",
|
||||
/*args=*/(ins),
|
||||
/*methodBody=*/"",
|
||||
|
@ -931,6 +931,20 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
|
|||
return getIndexingMaps()[i];
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Return the input indexing maps.
|
||||
}],
|
||||
/*retTy=*/"SmallVector<AffineMap>",
|
||||
/*methodName=*/"getInputIndexingMaps",
|
||||
/*args=*/(ins),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
auto maps = $_op.getIndexingMaps();
|
||||
return SmallVector<AffineMap>{maps.begin(),
|
||||
maps.begin() + $_op.getNumInputs()};
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Return the output indexing map at index `i`.
|
||||
|
@ -944,6 +958,20 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
|
|||
return getIndexingMaps()[i + $_op.getNumInputs()];
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Return the output indexing maps.
|
||||
}],
|
||||
/*retTy=*/"SmallVector<AffineMap>",
|
||||
/*methodName=*/"getOutputIndexingMaps",
|
||||
/*args=*/(ins),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
auto maps = $_op.getIndexingMaps();
|
||||
return SmallVector<AffineMap>{maps.begin() + $_op.getNumInputs(),
|
||||
maps.begin() + $_op.getNumShapedOperands()};
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Return whether the op has only MemRef input and outputs.
|
||||
|
|
|
@ -61,16 +61,14 @@ static bool areElementwiseOpsFusable(LinalgOp producer, LinalgOp consumer,
|
|||
/// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
|
||||
/// the `producer` to use in the fused operation given the indexing map of the
|
||||
/// result of the producer in the consumer.
|
||||
static void getIndexingMapOfProducerOperandsInFusedOp(
|
||||
LinalgOp producer, AffineMap fusedConsumerArgIndexMap,
|
||||
SmallVectorImpl<Attribute> &fusedOpIndexingMapAttrs) {
|
||||
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
|
||||
OpOperand &producerOpOperand, AffineMap producerResultIndexMap,
|
||||
AffineMap fusedConsumerArgIndexMap) {
|
||||
// The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
|
||||
// from consumer loop -> consumer arg tensor index/producer result tensor
|
||||
// index. The fused loop is same as the consumer loop. For each producer arg
|
||||
// the indexing map to be computed is a map from consumer loop -> producer
|
||||
// arg tensor index.
|
||||
|
||||
AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
|
||||
// producerResultIndexMap is a map from producer loop -> tensor index.
|
||||
// Compute the inverse to get map from tensor index -> producer loop.
|
||||
// The inverse is a map from producer result tensor index -> producer loop.
|
||||
|
@ -78,19 +76,19 @@ static void getIndexingMapOfProducerOperandsInFusedOp(
|
|||
inversePermutation(producerResultIndexMap);
|
||||
assert(invProducerResultIndexMap &&
|
||||
"expected producer result indexig map to be invertible");
|
||||
for (unsigned argNum : llvm::seq<unsigned>(0, producer.getNumInputs())) {
|
||||
// argMap is a map from producer loop -> producer arg tensor index.
|
||||
AffineMap argMap = producer.getInputIndexingMap(argNum);
|
||||
|
||||
// Compose argMap with invProducerResultIndexMap to get a map from
|
||||
// producer result tensor index -> producer arg tensor index.
|
||||
AffineMap t1 = argMap.compose(invProducerResultIndexMap);
|
||||
LinalgOp producer = cast<LinalgOp>(producerOpOperand.getOwner());
|
||||
// argMap is a map from producer loop -> producer arg tensor index.
|
||||
AffineMap argMap =
|
||||
producer.getIndexingMap(producerOpOperand.getOperandNumber());
|
||||
|
||||
// Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
|
||||
// consumer loop/ fused loop -> producer arg tensor index.
|
||||
AffineMap indexingMap = t1.compose(fusedConsumerArgIndexMap);
|
||||
fusedOpIndexingMapAttrs.push_back(AffineMapAttr::get(indexingMap));
|
||||
}
|
||||
// Compose argMap with invProducerResultIndexMap to get a map from
|
||||
// producer result tensor index -> producer arg tensor index.
|
||||
AffineMap t1 = argMap.compose(invProducerResultIndexMap);
|
||||
|
||||
// Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
|
||||
// consumer loop/ fused loop -> producer arg tensor index.
|
||||
return t1.compose(fusedConsumerArgIndexMap);
|
||||
}
|
||||
|
||||
/// Generate the region of the fused tensor operation. The region of the fused
|
||||
|
@ -163,6 +161,18 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp,
|
|||
.drop_front(numProducerIndices)
|
||||
.take_front(producer.getNumInputs()))
|
||||
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
|
||||
|
||||
// 4.b. Producer output operand/map that is fused needs to be mapped to the
|
||||
// producer bbArg if it is an "initTensor" (i.e. its value is actually read).
|
||||
assert(producer->getNumResults() == 1 && "expected single result producer");
|
||||
if (producer.isInitTensor(&producer.getOutputOpOperands()[0])) {
|
||||
BlockArgument bbArg =
|
||||
producerBlock.getArguments()
|
||||
.drop_front(numConsumerIndices + producer.getNumInputs())
|
||||
// TODO: bbArg index of
|
||||
.front();
|
||||
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
|
||||
}
|
||||
// 5. Remaining consumer's input operands (drop past index `consumerIdx`).
|
||||
for (BlockArgument bbArg : consumerBlock.getArguments()
|
||||
.drop_front(numConsumerIndices)
|
||||
|
@ -221,73 +231,90 @@ fuseElementwiseOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
|
|||
!controlFn(producer->getResult(0), consumerOpOperand))
|
||||
return llvm::None;
|
||||
|
||||
unsigned numFusedOperands =
|
||||
producer.getNumInputs() + consumer.getNumInputs() - 1;
|
||||
// TODO: allow fusing the producer of an output operand.
|
||||
assert(consumerIdx < consumer.getNumInputs() &&
|
||||
"expected producer of input operand");
|
||||
|
||||
// Compute the fused operands list,
|
||||
SmallVector<Value, 2> fusedOperands;
|
||||
fusedOperands.reserve(numFusedOperands);
|
||||
auto consumerOperands = consumer.getInputs();
|
||||
auto producerOperands = producer.getInputs();
|
||||
fusedOperands.assign(consumerOperands.begin(),
|
||||
std::next(consumerOperands.begin(), consumerIdx));
|
||||
fusedOperands.append(producerOperands.begin(), producerOperands.end());
|
||||
fusedOperands.append(std::next(consumerOperands.begin(), consumerIdx + 1),
|
||||
consumerOperands.end());
|
||||
|
||||
// Compute indexing_maps for the fused operation. The indexing_maps for the
|
||||
// operands of the consumers that aren't fused are the same. The
|
||||
// indexing_maps for the producers need to be computed based on the
|
||||
// indexing_map of the operand at consumerIdx in the consumer.
|
||||
SmallVector<Attribute, 4> fusedIndexMaps;
|
||||
auto consumerIndexMaps = consumer.indexing_maps();
|
||||
fusedIndexMaps.reserve(fusedOperands.size() + consumer.getNumOutputs());
|
||||
fusedIndexMaps.assign(consumerIndexMaps.begin(),
|
||||
std::next(consumerIndexMaps.begin(), consumerIdx));
|
||||
// Compute indexing maps for the producer args in the fused operation.
|
||||
getIndexingMapOfProducerOperandsInFusedOp(
|
||||
producer, consumer.getInputIndexingMap(consumerIdx), fusedIndexMaps);
|
||||
|
||||
// Append the indexing maps for the remaining consumer operands.
|
||||
fusedIndexMaps.append(std::next(consumerIndexMaps.begin(), consumerIdx + 1),
|
||||
consumerIndexMaps.end());
|
||||
// Compute the fused operands list and indexing maps.
|
||||
SmallVector<Value> fusedOperands;
|
||||
SmallVector<AffineMap> fusedIndexMaps;
|
||||
fusedOperands.reserve(producer->getNumOperands() +
|
||||
consumer->getNumOperands());
|
||||
fusedIndexMaps.reserve(producer->getNumOperands() +
|
||||
consumer->getNumOperands());
|
||||
// In the following, numbering matches that of `generateFusedTensorOpRegion`.
|
||||
// 3. Consumer input operands/maps up to consumerIdx (exclusive).
|
||||
llvm::append_range(fusedOperands,
|
||||
consumer.getInputs().take_front(consumerIdx));
|
||||
llvm::append_range(
|
||||
fusedIndexMaps,
|
||||
ArrayRef<AffineMap>{consumer.getInputIndexingMaps()}.take_front(
|
||||
consumerIdx));
|
||||
// 4. Splice in producer's input operands/maps.
|
||||
llvm::append_range(fusedOperands, producer.getInputs());
|
||||
assert(producer->getNumResults() == 1 && "expected single result producer");
|
||||
AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
|
||||
for (auto &inputOpOperand : producer.getInputOpOperands()) {
|
||||
// Compute indexing maps for the producer args in the fused operation.
|
||||
AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
|
||||
inputOpOperand, producerResultIndexMap,
|
||||
consumer.getInputIndexingMap(consumerIdx));
|
||||
fusedIndexMaps.push_back(map);
|
||||
}
|
||||
// 4.b. Producer output operand/map that is fused needs to be passed if it is
|
||||
// an "initTensor" (i.e. its value is actually read).
|
||||
assert(producer->getNumResults() == 1 && "expected single result producer");
|
||||
if (producer.isInitTensor(&producer.getOutputOpOperands()[0])) {
|
||||
llvm::append_range(fusedOperands, producer.getOutputs().take_front());
|
||||
// Compute indexing maps for the producer args in the fused operation.
|
||||
AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
|
||||
producer.getOutputOpOperands().front(), producerResultIndexMap,
|
||||
consumer.getOutputIndexingMap(0));
|
||||
fusedIndexMaps.push_back(map);
|
||||
}
|
||||
// 5. Remaining consumer's input operands/maps (drop past index
|
||||
// `consumerIdx`).
|
||||
llvm::append_range(fusedOperands,
|
||||
consumer.getInputs().drop_front(consumerIdx + 1));
|
||||
llvm::append_range(
|
||||
fusedIndexMaps,
|
||||
ArrayRef<AffineMap>{consumer.getInputIndexingMaps()}.drop_front(
|
||||
consumerIdx + 1));
|
||||
// 6. All of consumer's output operands (skip operands: added by the builder).
|
||||
// llvm::append_range(fusedOperands, consumer.getOutputs());
|
||||
llvm::append_range(fusedIndexMaps, consumer.getOutputIndexingMaps());
|
||||
// 7. All of producer's output operands/maps except the one fused.
|
||||
// TODO: allow fusion of multi-result producers.
|
||||
assert(producer->getNumResults() == 1 && "expected single result producer");
|
||||
|
||||
// Generate the fused op.
|
||||
LinalgOp fusedOp;
|
||||
Operation *fusedOp;
|
||||
if (isa<GenericOp>(producer.getOperation()) &&
|
||||
isa<GenericOp>(consumer.getOperation())) {
|
||||
fusedOp =
|
||||
rewriter
|
||||
.create<GenericOp>(consumer.getLoc(), consumer->getResultTypes(),
|
||||
/*inputs=*/fusedOperands,
|
||||
// TODO: handle outputs.
|
||||
consumer.getOutputs(),
|
||||
rewriter.getArrayAttr(fusedIndexMaps),
|
||||
consumer.iterator_types(),
|
||||
/*doc=*/nullptr,
|
||||
/*library_call=*/nullptr,
|
||||
/*sparse=*/nullptr)
|
||||
.getOperation();
|
||||
fusedOp = rewriter.create<GenericOp>(
|
||||
consumer.getLoc(), consumer->getResultTypes(),
|
||||
/*inputs=*/fusedOperands,
|
||||
// TODO: handle outputs.
|
||||
consumer.getOutputs(), rewriter.getAffineMapArrayAttr(fusedIndexMaps),
|
||||
consumer.iterator_types(),
|
||||
/*doc=*/nullptr,
|
||||
/*library_call=*/nullptr,
|
||||
/*sparse=*/nullptr);
|
||||
} else {
|
||||
fusedOp =
|
||||
rewriter
|
||||
.create<IndexedGenericOp>(
|
||||
consumer.getLoc(), consumer->getResultTypes(),
|
||||
/*inputs=*/fusedOperands,
|
||||
// TODO: handle outputs.
|
||||
consumer.getOutputs(), rewriter.getArrayAttr(fusedIndexMaps),
|
||||
consumer.iterator_types(),
|
||||
/*doc=*/nullptr,
|
||||
/*library_call=*/nullptr,
|
||||
/*sparse=*/nullptr)
|
||||
.getOperation();
|
||||
fusedOp = rewriter.create<IndexedGenericOp>(
|
||||
consumer.getLoc(), consumer->getResultTypes(),
|
||||
/*inputs=*/fusedOperands,
|
||||
// TODO: handle outputs.
|
||||
consumer.getOutputs(), rewriter.getAffineMapArrayAttr(fusedIndexMaps),
|
||||
consumer.iterator_types(),
|
||||
/*doc=*/nullptr,
|
||||
/*library_call=*/nullptr,
|
||||
/*sparse=*/nullptr);
|
||||
}
|
||||
|
||||
// Construct an AffineMap from consumer loops to producer loops.
|
||||
// consumer loop -> tensor index
|
||||
AffineMap consumerResultIndexMap = consumer.getInputIndexingMap(consumerIdx);
|
||||
// producer loop -> tensor index
|
||||
AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
|
||||
// tensor index -> producer loop
|
||||
AffineMap invProducerResultIndexMap =
|
||||
inversePermutation(producerResultIndexMap);
|
||||
|
@ -297,9 +324,9 @@ fuseElementwiseOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
|
|||
AffineMap consumerToProducerLoopsMap =
|
||||
invProducerResultIndexMap.compose(consumerResultIndexMap);
|
||||
|
||||
generateFusedElementwiseOpRegion(rewriter, fusedOp.getOperation(), producer,
|
||||
consumer, consumerToProducerLoopsMap,
|
||||
consumerIdx, consumer.getNumLoops());
|
||||
generateFusedElementwiseOpRegion(rewriter, fusedOp, producer, consumer,
|
||||
consumerToProducerLoopsMap, consumerIdx,
|
||||
consumer.getNumLoops());
|
||||
return SmallVector<Value, 1>(fusedOp->getResults());
|
||||
}
|
||||
|
||||
|
|
|
@ -616,3 +616,39 @@ func @sigmoid_dynamic_dim(%0: tensor<?x1xf32>) -> tensor<?x1xf32> {
|
|||
} -> tensor<?x1xf32>
|
||||
return %2 : tensor<?x1xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func private @compute1(%a: f64) -> f64
|
||||
func private @compute2(%a: f64, %b: i32) -> i32
|
||||
|
||||
// CHECK-LABEL: func @generic_index_op2(
|
||||
func @generic_index_op2(%arg0: tensor<1x8xf64>, %arg1: tensor<1x8xi32>) -> tensor<1x8xi32> {
|
||||
%0 = linalg.generic {
|
||||
indexing_maps = [affine_map<(i, j) -> (i, j)>],
|
||||
iterator_types = ["parallel", "parallel"]}
|
||||
outs(%arg0 : tensor<1x8xf64>) {
|
||||
^bb0(%a: f64):
|
||||
%r = call @compute1(%a) : (f64) -> f64
|
||||
linalg.yield %r : f64
|
||||
} -> tensor<1x8xf64>
|
||||
|
||||
// CHECK-NEXT: %[[R:.*]] = linalg.generic
|
||||
// CHECK: bb0(%[[BBA:[0-9a-z]*]]: f64, %[[BBB:[0-9a-z]*]]: i32):
|
||||
// CHECK-NEXT: %[[A:.*]] = call @compute1(%[[BBA]]) : (f64) -> f64
|
||||
// CHECK-NEXT: %[[B:.*]] = call @compute2(%[[A]], %[[BBB]]) : (f64, i32) -> i32
|
||||
// CHECK-NEXT: linalg.yield %[[B]] : i32
|
||||
// CHECK-NEXT: } -> tensor<1x8xi32>
|
||||
%1 = linalg.generic {
|
||||
indexing_maps = [affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>],
|
||||
iterator_types = ["parallel", "parallel"]}
|
||||
ins(%0 : tensor<1x8xf64>)
|
||||
outs(%arg1 : tensor<1x8xi32>) {
|
||||
^bb0(%a: f64, %b: i32):
|
||||
%r = call @compute2(%a, %b) : (f64, i32) -> i32
|
||||
linalg.yield %r : i32
|
||||
} -> tensor<1x8xi32>
|
||||
|
||||
// CHECK-NEXT: return %[[R]] : tensor<1x8xi32>
|
||||
return %1 : tensor<1x8xi32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue