[mlir][scf] Add scf.for + tensor.cast canonicalization pattern
Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
tensor.cast op pair, so as to pull the tensor.cast inside the scf.for:

```
%0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
%1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
    -> (tensor<?x?xf32>) {
  %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
  scf.yield %2 : tensor<?x?xf32>
}
%2 = tensor.cast %1 : tensor<?x?xf32> to tensor<32x1024xf32>
use_of(%2)
```

folds into:

```
%0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
    -> (tensor<32x1024xf32>) {
  %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
  %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
  %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
  scf.yield %4 : tensor<32x1024xf32>
}
use_of(%0)
```

Differential Revision: https://reviews.llvm.org/D100661
commit 843f1fc825 (parent 244d9d6e41)
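For context beyond the commit message (an illustrative sketch, not part of the commit): the new pattern only fires when the loop result feeding the outgoing cast has exactly one use. IR like the following is deliberately left alone; `@consume_dynamic` is a hypothetical second consumer added here for illustration:

```mlir
%0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
%1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0) -> (tensor<?x?xf32>) {
  %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
  scf.yield %2 : tensor<?x?xf32>
}
// %1 has a second use besides the outgoing tensor.cast, so the hasOneUse
// check in ForOpTensorCastFolder (below) rejects the fold and the IR is
// left unchanged.
%2 = tensor.cast %1 : tensor<?x?xf32> to tensor<32x1024xf32>
call @consume_dynamic(%1) : (tensor<?x?xf32>) -> ()
```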
mlir/lib/Dialect/SCF/SCF.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/MathExtras.h"
@@ -578,6 +579,140 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
   }
 };
 
+/// Perform a replacement of one iter OpOperand of an scf.for to the
+/// `replacement` value which is expected to be the source of a tensor.cast.
+/// tensor.cast ops are inserted inside the block to account for the type cast.
+static ForOp replaceTensorCastForOpIterArg(PatternRewriter &rewriter,
+                                           OpOperand &operand,
+                                           Value replacement) {
+  Type oldType = operand.get().getType(), newType = replacement.getType();
+  assert(oldType.isa<RankedTensorType>() && newType.isa<RankedTensorType>() &&
+         "expected ranked tensor types");
+
+  // 1. Create new iter operands, exactly 1 is replaced.
+  ForOp forOp = cast<ForOp>(operand.getOwner());
+  assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
+         "expected an iter OpOperand");
+  if (operand.get().getType() == replacement.getType())
+    return forOp;
+  SmallVector<Value> newIterOperands;
+  for (OpOperand &opOperand : forOp.getIterOpOperands()) {
+    if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
+      newIterOperands.push_back(replacement);
+      continue;
+    }
+    newIterOperands.push_back(opOperand.get());
+  }
+
+  // 2. Create the new forOp shell.
+  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
+      forOp.getLoc(), forOp.lowerBound(), forOp.upperBound(), forOp.step(),
+      newIterOperands);
+  Block &newBlock = newForOp.region().front();
+  SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
+                                             newBlock.getArguments().end());
+
+  // 3. Inject an incoming cast op at the beginning of the block for the bbArg
+  // corresponding to the `replacement` value.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(&newBlock, newBlock.begin());
+  BlockArgument newRegionIterArg = newForOp.getRegionIterArgForOpOperand(
+      newForOp->getOpOperand(operand.getOperandNumber()));
+  Value castIn = rewriter.create<tensor::CastOp>(newForOp.getLoc(), oldType,
+                                                 newRegionIterArg);
+  newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
+
+  // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
+  Block &oldBlock = forOp.region().front();
+  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
+
+  // 5. Inject an outgoing cast op at the end of the block and yield it instead.
+  auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
+  rewriter.setInsertionPoint(clonedYieldOp);
+  unsigned yieldIdx =
+      newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
+  Value castOut = rewriter.create<tensor::CastOp>(
+      newForOp.getLoc(), newType, clonedYieldOp.getOperand(yieldIdx));
+  SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
+  newYieldOperands[yieldIdx] = castOut;
+  rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
+  rewriter.eraseOp(clonedYieldOp);
+
+  // 6. Inject an outgoing cast op after the forOp.
+  rewriter.setInsertionPointAfter(newForOp);
+  SmallVector<Value> newResults = newForOp.getResults();
+  newResults[yieldIdx] = rewriter.create<tensor::CastOp>(
+      newForOp.getLoc(), oldType, newResults[yieldIdx]);
+
+  return newForOp;
+}
+
+/// Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
+/// tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
+///
+/// ```
+///   %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
+///   %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
+///       -> (tensor<?x?xf32>) {
+///     %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+///     scf.yield %2 : tensor<?x?xf32>
+///   }
+///   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<32x1024xf32>
+///   use_of(%2)
+/// ```
+///
+/// folds into:
+///
+/// ```
+///   %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
+///       -> (tensor<32x1024xf32>) {
+///     %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
+///     %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+///     %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
+///     scf.yield %4 : tensor<32x1024xf32>
+///   }
+///   use_of(%0)
+/// ```
+struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
+  using OpRewritePattern<ForOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ForOp op,
+                                PatternRewriter &rewriter) const override {
+    for (auto it : llvm::zip(op.getIterOpOperands(), op.getResults())) {
+      OpOperand &iterOpOperand = std::get<0>(it);
+      auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
+      if (!incomingCast)
+        continue;
+      if (!std::get<1>(it).hasOneUse())
+        continue;
+      auto outgoingCastOp =
+          dyn_cast<tensor::CastOp>(*std::get<1>(it).user_begin());
+      if (!outgoingCastOp)
+        continue;
+
+      // Must be a tensor.cast op pair with matching types.
+      if (outgoingCastOp.getResult().getType() !=
+          incomingCast.source().getType())
+        continue;
+
+      // Create a new ForOp with that iter operand replaced.
+      auto newForOp = replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
+                                                    incomingCast.source());
+
+      // Insert outgoing cast and use it to replace the corresponding result.
+      rewriter.setInsertionPointAfter(newForOp);
+      SmallVector<Value> replacements = newForOp.getResults();
+      unsigned returnIdx =
+          iterOpOperand.getOperandNumber() - op.getNumControlOperands();
+      replacements[returnIdx] = rewriter.create<tensor::CastOp>(
+          op.getLoc(), incomingCast.dest().getType(), replacements[returnIdx]);
+      rewriter.replaceOp(op, replacements);
+      return success();
+    }
+    return failure();
+  }
+};
+
 /// Canonicalize the iter_args of an scf::ForOp that involve a tensor_load and
 /// for which only the last loop iteration is actually visible outside of the
 /// loop. The canonicalization looks for a pattern such as:
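Similarly (another illustrative sketch, not part of the diff): the matching-type check in `matchAndRewrite` above means a cast pair that does not round-trip to the exact original type is skipped:

```mlir
// The incoming cast source type is tensor<32x1024xf32>, but the outgoing
// cast produces tensor<32x?xf32>; the types differ, so the pattern bails out
// at the "Must be a tensor.cast op pair with matching types" check.
%0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
%1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0) -> (tensor<?x?xf32>) {
  %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
  scf.yield %2 : tensor<?x?xf32>
}
%2 = tensor.cast %1 : tensor<?x?xf32> to tensor<32x?xf32>
```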
@@ -706,7 +841,7 @@ struct LastTensorLoadCanonicalization : public OpRewritePattern<ForOp> {
 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
   results.add<ForOpIterArgsFolder, SimplifyTrivialLoops,
-              LastTensorLoadCanonicalization>(context);
+              LastTensorLoadCanonicalization, ForOpTensorCastFolder>(context);
 }
 
 //===----------------------------------------------------------------------===//
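Because `ForOpTensorCastFolder` is registered through `ForOp::getCanonicalizationPatterns`, any pipeline that runs the canonicalizer picks it up; for example, running `mlir-opt input.mlir -canonicalize` (assuming a built `mlir-opt` and an `input.mlir` containing IR like the examples above) should perform the fold exercised by the test below.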
mlir/test/Dialect/SCF/canonicalize.mlir
@@ -580,3 +580,33 @@ func @fold_away_iter_and_result_with_no_use(%arg0 : i32,
   // CHECK: return %[[FOR_RES]] : i32
   return %0#0 : i32
 }
+
+// -----
+
+func private @do(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// CHECK-LABEL: matmul_on_tensors
+// CHECK-SAME: %[[T0:[0-9a-z]*]]: tensor<32x1024xf32>
+// CHECK-SAME: %[[T1:[0-9a-z]*]]: tensor<1024x1024xf32>
+func @matmul_on_tensors(%t0: tensor<32x1024xf32>, %t1: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
+  %c0 = constant 0 : index
+  %c32 = constant 32 : index
+  %c1024 = constant 1024 : index
+  // CHECK-NOT: tensor.cast
+  // CHECK: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args(%[[ITER_T0:.*]] = %[[T0]]) -> (tensor<32x1024xf32>) {
+  // CHECK:   %[[CAST:.*]] = tensor.cast %[[ITER_T0]] : tensor<32x1024xf32> to tensor<?x?xf32>
+  // CHECK:   %[[DONE:.*]] = call @do(%[[CAST]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+  // CHECK:   %[[UNCAST:.*]] = tensor.cast %[[DONE]] : tensor<?x?xf32> to tensor<32x1024xf32>
+  // CHECK:   scf.yield %[[UNCAST]] : tensor<32x1024xf32>
+  %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
+  %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0) -> (tensor<?x?xf32>) {
+    %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+    scf.yield %2 : tensor<?x?xf32>
+  }
+  // CHECK-NOT: tensor.cast
+  // CHECK: %[[RES:.*]] = subtensor_insert %[[FOR_RES]] into %[[T1]][0, 0] [32, 1024] [1, 1] : tensor<32x1024xf32> into tensor<1024x1024xf32>
+  // CHECK: return %[[RES]] : tensor<1024x1024xf32>
+  %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<32x1024xf32>
+  %res = subtensor_insert %2 into %t1[0, 0] [32, 1024] [1, 1] : tensor<32x1024xf32> into tensor<1024x1024xf32>
+  return %res : tensor<1024x1024xf32>
+}
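The `// -----` separator indicates the test file is run with `-split-input-file`; the RUN line itself sits outside this diff hunk, but it presumably looks something like the following (an assumption, not part of the commit):

```mlir
// Assumed RUN line for the test file, not shown in this diff:
// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s
```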