diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td index 4fdd9a2221f0..95e008aacc45 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -896,7 +896,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { /*desc=*/[{ Return the indexing maps within the current operation. }], - /*retTy=*/"SmallVector", + /*retTy=*/"SmallVector", /*methodName=*/"getIndexingMaps", /*args=*/(ins), /*methodBody=*/"", @@ -931,6 +931,20 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { return getIndexingMaps()[i]; }] >, + InterfaceMethod< + /*desc=*/[{ + Return the input indexing maps. + }], + /*retTy=*/"SmallVector", + /*methodName=*/"getInputIndexingMaps", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + auto maps = $_op.getIndexingMaps(); + return SmallVector{maps.begin(), + maps.begin() + $_op.getNumInputs()}; + }] + >, InterfaceMethod< /*desc=*/[{ Return the output indexing map at index `i`. @@ -944,6 +958,20 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { return getIndexingMaps()[i + $_op.getNumInputs()]; }] >, + InterfaceMethod< + /*desc=*/[{ + Return the output indexing maps. + }], + /*retTy=*/"SmallVector", + /*methodName=*/"getOutputIndexingMaps", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + auto maps = $_op.getIndexingMaps(); + return SmallVector{maps.begin() + $_op.getNumInputs(), + maps.begin() + $_op.getNumShapedOperands()}; + }] + >, InterfaceMethod< /*desc=*/[{ Return whether the op has only MemRef input and outputs. diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp index bb1a051c78e5..34eac4bdfcaa 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -61,16 +61,14 @@ static bool areElementwiseOpsFusable(LinalgOp producer, LinalgOp consumer, /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of /// the `producer` to use in the fused operation given the indexing map of the /// result of the producer in the consumer. -static void getIndexingMapOfProducerOperandsInFusedOp( - LinalgOp producer, AffineMap fusedConsumerArgIndexMap, - SmallVectorImpl &fusedOpIndexingMapAttrs) { +static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( + OpOperand &producerOpOperand, AffineMap producerResultIndexMap, + AffineMap fusedConsumerArgIndexMap) { // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map // from consumer loop -> consumer arg tensor index/producer result tensor // index. The fused loop is same as the consumer loop. For each producer arg // the indexing map to be computed is a map from consumer loop -> producer // arg tensor index. - - AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0); // producerResultIndexMap is a map from producer loop -> tensor index. // Compute the inverse to get map from tensor index -> producer loop. // The inverse is a map from producer result tensor index -> producer loop. @@ -78,19 +76,19 @@ static void getIndexingMapOfProducerOperandsInFusedOp( inversePermutation(producerResultIndexMap); assert(invProducerResultIndexMap && "expected producer result indexig map to be invertible"); - for (unsigned argNum : llvm::seq(0, producer.getNumInputs())) { - // argMap is a map from producer loop -> producer arg tensor index. - AffineMap argMap = producer.getInputIndexingMap(argNum); - // Compose argMap with invProducerResultIndexMap to get a map from - // producer result tensor index -> producer arg tensor index. - AffineMap t1 = argMap.compose(invProducerResultIndexMap); + LinalgOp producer = cast(producerOpOperand.getOwner()); + // argMap is a map from producer loop -> producer arg tensor index. + AffineMap argMap = + producer.getIndexingMap(producerOpOperand.getOperandNumber()); - // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from - // consumer loop/ fused loop -> producer arg tensor index. - AffineMap indexingMap = t1.compose(fusedConsumerArgIndexMap); - fusedOpIndexingMapAttrs.push_back(AffineMapAttr::get(indexingMap)); - } + // Compose argMap with invProducerResultIndexMap to get a map from + // producer result tensor index -> producer arg tensor index. + AffineMap t1 = argMap.compose(invProducerResultIndexMap); + + // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from + // consumer loop/ fused loop -> producer arg tensor index. + return t1.compose(fusedConsumerArgIndexMap); } /// Generate the region of the fused tensor operation. The region of the fused @@ -163,6 +161,18 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp, .drop_front(numProducerIndices) .take_front(producer.getNumInputs())) mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType())); + + // 4.b. Producer output operand/map that is fused needs to be mapped to the + // producer bbArg if it is an "initTensor" (i.e. its value is actually read). + assert(producer->getNumResults() == 1 && "expected single result producer"); + if (producer.isInitTensor(&producer.getOutputOpOperands()[0])) { + BlockArgument bbArg = + producerBlock.getArguments() + .drop_front(numConsumerIndices + producer.getNumInputs()) + // TODO: bbArg index of + .front(); + mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType())); + } // 5. Remaining consumer's input operands (drop past index `consumerIdx`). for (BlockArgument bbArg : consumerBlock.getArguments() .drop_front(numConsumerIndices) @@ -221,73 +231,90 @@ fuseElementwiseOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand, !controlFn(producer->getResult(0), consumerOpOperand)) return llvm::None; - unsigned numFusedOperands = - producer.getNumInputs() + consumer.getNumInputs() - 1; + // TODO: allow fusing the producer of an output operand. + assert(consumerIdx < consumer.getNumInputs() && + "expected producer of input operand"); - // Compute the fused operands list, - SmallVector fusedOperands; - fusedOperands.reserve(numFusedOperands); - auto consumerOperands = consumer.getInputs(); - auto producerOperands = producer.getInputs(); - fusedOperands.assign(consumerOperands.begin(), - std::next(consumerOperands.begin(), consumerIdx)); - fusedOperands.append(producerOperands.begin(), producerOperands.end()); - fusedOperands.append(std::next(consumerOperands.begin(), consumerIdx + 1), - consumerOperands.end()); - - // Compute indexing_maps for the fused operation. The indexing_maps for the - // operands of the consumers that aren't fused are the same. The - // indexing_maps for the producers need to be computed based on the - // indexing_map of the operand at consumerIdx in the consumer. - SmallVector fusedIndexMaps; - auto consumerIndexMaps = consumer.indexing_maps(); - fusedIndexMaps.reserve(fusedOperands.size() + consumer.getNumOutputs()); - fusedIndexMaps.assign(consumerIndexMaps.begin(), - std::next(consumerIndexMaps.begin(), consumerIdx)); - // Compute indexing maps for the producer args in the fused operation. - getIndexingMapOfProducerOperandsInFusedOp( - producer, consumer.getInputIndexingMap(consumerIdx), fusedIndexMaps); - - // Append the indexing maps for the remaining consumer operands. - fusedIndexMaps.append(std::next(consumerIndexMaps.begin(), consumerIdx + 1), - consumerIndexMaps.end()); + // Compute the fused operands list and indexing maps. + SmallVector fusedOperands; + SmallVector fusedIndexMaps; + fusedOperands.reserve(producer->getNumOperands() + + consumer->getNumOperands()); + fusedIndexMaps.reserve(producer->getNumOperands() + + consumer->getNumOperands()); + // In the following, numbering matches that of `generateFusedTensorOpRegion`. + // 3. Consumer input operands/maps up to consumerIdx (exclusive). + llvm::append_range(fusedOperands, + consumer.getInputs().take_front(consumerIdx)); + llvm::append_range( + fusedIndexMaps, + ArrayRef{consumer.getInputIndexingMaps()}.take_front( + consumerIdx)); + // 4. Splice in producer's input operands/maps. + llvm::append_range(fusedOperands, producer.getInputs()); + assert(producer->getNumResults() == 1 && "expected single result producer"); + AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0); + for (auto &inputOpOperand : producer.getInputOpOperands()) { + // Compute indexing maps for the producer args in the fused operation. + AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( + inputOpOperand, producerResultIndexMap, + consumer.getInputIndexingMap(consumerIdx)); + fusedIndexMaps.push_back(map); + } + // 4.b. Producer output operand/map that is fused needs to be passed if it is + // an "initTensor" (i.e. its value is actually read). + assert(producer->getNumResults() == 1 && "expected single result producer"); + if (producer.isInitTensor(&producer.getOutputOpOperands()[0])) { + llvm::append_range(fusedOperands, producer.getOutputs().take_front()); + // Compute indexing maps for the producer args in the fused operation. + AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( + producer.getOutputOpOperands().front(), producerResultIndexMap, + consumer.getOutputIndexingMap(0)); + fusedIndexMaps.push_back(map); + } + // 5. Remaining consumer's input operands/maps (drop past index + // `consumerIdx`). + llvm::append_range(fusedOperands, + consumer.getInputs().drop_front(consumerIdx + 1)); + llvm::append_range( + fusedIndexMaps, + ArrayRef{consumer.getInputIndexingMaps()}.drop_front( + consumerIdx + 1)); + // 6. All of consumer's output operands (skip operands: added by the builder). + // llvm::append_range(fusedOperands, consumer.getOutputs()); + llvm::append_range(fusedIndexMaps, consumer.getOutputIndexingMaps()); + // 7. All of producer's output operands/maps except the one fused. + // TODO: allow fusion of multi-result producers. + assert(producer->getNumResults() == 1 && "expected single result producer"); // Generate the fused op. - LinalgOp fusedOp; + Operation *fusedOp; if (isa(producer.getOperation()) && isa(consumer.getOperation())) { - fusedOp = - rewriter - .create(consumer.getLoc(), consumer->getResultTypes(), - /*inputs=*/fusedOperands, - // TODO: handle outputs. - consumer.getOutputs(), - rewriter.getArrayAttr(fusedIndexMaps), - consumer.iterator_types(), - /*doc=*/nullptr, - /*library_call=*/nullptr, - /*sparse=*/nullptr) - .getOperation(); + fusedOp = rewriter.create( + consumer.getLoc(), consumer->getResultTypes(), + /*inputs=*/fusedOperands, + // TODO: handle outputs. + consumer.getOutputs(), rewriter.getAffineMapArrayAttr(fusedIndexMaps), + consumer.iterator_types(), + /*doc=*/nullptr, + /*library_call=*/nullptr, + /*sparse=*/nullptr); } else { - fusedOp = - rewriter - .create( - consumer.getLoc(), consumer->getResultTypes(), - /*inputs=*/fusedOperands, - // TODO: handle outputs. - consumer.getOutputs(), rewriter.getArrayAttr(fusedIndexMaps), - consumer.iterator_types(), - /*doc=*/nullptr, - /*library_call=*/nullptr, - /*sparse=*/nullptr) - .getOperation(); + fusedOp = rewriter.create( + consumer.getLoc(), consumer->getResultTypes(), + /*inputs=*/fusedOperands, + // TODO: handle outputs. + consumer.getOutputs(), rewriter.getAffineMapArrayAttr(fusedIndexMaps), + consumer.iterator_types(), + /*doc=*/nullptr, + /*library_call=*/nullptr, + /*sparse=*/nullptr); } // Construct an AffineMap from consumer loops to producer loops. // consumer loop -> tensor index AffineMap consumerResultIndexMap = consumer.getInputIndexingMap(consumerIdx); - // producer loop -> tensor index - AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0); // tensor index -> producer loop AffineMap invProducerResultIndexMap = inversePermutation(producerResultIndexMap); @@ -297,9 +324,9 @@ fuseElementwiseOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand, AffineMap consumerToProducerLoopsMap = invProducerResultIndexMap.compose(consumerResultIndexMap); - generateFusedElementwiseOpRegion(rewriter, fusedOp.getOperation(), producer, - consumer, consumerToProducerLoopsMap, - consumerIdx, consumer.getNumLoops()); + generateFusedElementwiseOpRegion(rewriter, fusedOp, producer, consumer, + consumerToProducerLoopsMap, consumerIdx, + consumer.getNumLoops()); return SmallVector(fusedOp->getResults()); } diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir index 13109bd98c19..b0a006398c99 100644 --- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir +++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir @@ -616,3 +616,39 @@ func @sigmoid_dynamic_dim(%0: tensor) -> tensor { } -> tensor return %2 : tensor } + +// ----- + +func private @compute1(%a: f64) -> f64 +func private @compute2(%a: f64, %b: i32) -> i32 + +// CHECK-LABEL: func @generic_index_op2( +func @generic_index_op2(%arg0: tensor<1x8xf64>, %arg1: tensor<1x8xi32>) -> tensor<1x8xi32> { + %0 = linalg.generic { + indexing_maps = [affine_map<(i, j) -> (i, j)>], + iterator_types = ["parallel", "parallel"]} + outs(%arg0 : tensor<1x8xf64>) { + ^bb0(%a: f64): + %r = call @compute1(%a) : (f64) -> f64 + linalg.yield %r : f64 + } -> tensor<1x8xf64> + + // CHECK-NEXT: %[[R:.*]] = linalg.generic + // CHECK: bb0(%[[BBA:[0-9a-z]*]]: f64, %[[BBB:[0-9a-z]*]]: i32): + // CHECK-NEXT: %[[A:.*]] = call @compute1(%[[BBA]]) : (f64) -> f64 + // CHECK-NEXT: %[[B:.*]] = call @compute2(%[[A]], %[[BBB]]) : (f64, i32) -> i32 + // CHECK-NEXT: linalg.yield %[[B]] : i32 + // CHECK-NEXT: } -> tensor<1x8xi32> + %1 = linalg.generic { + indexing_maps = [affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>], + iterator_types = ["parallel", "parallel"]} + ins(%0 : tensor<1x8xf64>) + outs(%arg1 : tensor<1x8xi32>) { + ^bb0(%a: f64, %b: i32): + %r = call @compute2(%a, %b) : (f64, i32) -> i32 + linalg.yield %r : i32 + } -> tensor<1x8xi32> + + // CHECK-NEXT: return %[[R]] : tensor<1x8xi32> + return %1 : tensor<1x8xi32> +}